poi/
expr.rs

1use std::fmt;
2use std::ops::{Add, Sub, Mul};
3
4use super::*;
5
6/// Function expression.
7#[derive(Clone, PartialEq, PartialOrd, Debug)]
8pub enum Expr {
9    /// A symbol that is used together with symbolic knowledge.
10    Sym(Symbol),
11    /// Some function that returns a value, ignoring the argument.
12    ///
13    /// This can also be used to store values, since zero arguments is a value.
14    Ret(Value),
15    /// A binary operation on functions.
16    EOp(Op, Box<Expr>, Box<Expr>),
17    /// A tuple for more than one argument.
18    Tup(Vec<Expr>),
19    /// A list.
20    List(Vec<Expr>),
21}
22
23impl Add for Expr {
24    type Output = Expr;
25    fn add(self, other: Expr) -> Expr {app2(Add, self, other)}
26}
27
28impl Sub for Expr {
29    type Output = Expr;
30    fn sub(self, other: Expr) -> Expr {app2(Sub, self, other)}
31}
32
33impl Mul for Expr {
34    type Output = Expr;
35    fn mul(self, other: Expr) -> Expr {app2(Mul, self, other)}
36}
37
38impl Expr {
39    /// Used to display format with additional options.
40    pub fn display(
41        &self,
42        w: &mut fmt::Formatter<'_>,
43        parens: bool,
44        rule: bool,
45    ) -> std::result::Result<(), fmt::Error> {
46        match self {
47            Sym(s) => s.display(w, rule)?,
48            Ret(v) => write!(w, "{}", v)?,
49            EOp(Path, a, b) => {
50                if let Tup(b) = &**b {
51                    let parens = true;
52                    a.display(w, parens, rule)?;
53                    write!(w, "[")?;
54                    for i in 0..b.len() {
55                        if i > 0 {
56                            if i + 1 < b.len() {
57                                write!(w, " ⨯ ")?
58                            } else {
59                                write!(w, " → ")?
60                            }
61                        }
62                        b[i].display(w, true, rule)?;
63                    }
64                    write!(w, "]")?
65                } else {
66                    a.display(w, true, rule)?;
67                    write!(w, "[")?;
68                    b.display(w, false, rule)?;
69                    write!(w, "]")?;
70                }
71            }
72            EOp(Apply, a, b) => {
73                let mut r = |op: &str| -> std::result::Result<(), fmt::Error> {
74                    write!(w, "({} ", op)?;
75                    b.display(w, false, rule)?;
76                    write!(w, ")")
77                };
78                if let Sym(Neg) = **a {
79                    if parens {
80                        write!(w, "(")?;
81                    }
82                    write!(w, "-")?;
83                    b.display(w, true, rule)?;
84                    if parens {
85                        write!(w, ")")?;
86                    }
87                } else if let Sym(Not) = **a {
88                    if parens {
89                        write!(w, "(")?;
90                    }
91                    write!(w, "!")?;
92                    b.display(w, true, rule)?;
93                    if parens {
94                        write!(w, ")")?;
95                    }
96                } else if let Sym(Rty) = **a {
97                    if let Sym(_) = **b {
98                        r(":")?;
99                    }
100                } else if let Sym(Rlt) = **a {
101                    r("<")?;
102                } else if let Sym(Rle) = **a {
103                    r("<=")?;
104                } else if let Sym(Eq) = **a {
105                    r("=")?;
106                } else if let Sym(Rgt) = **a {
107                    r(">")?;
108                } else if let Sym(Rge) = **a {
109                    r(">=")?;
110                } else if let Sym(Mul) = **a {
111                    r("*")?;
112                } else if let Sym(Add) = **a {
113                    r("+")?;
114                } else if let Sym(Rsub) = **a {
115                    r("-")?;
116                } else if let Sym(Rdiv) = **a {
117                    r("/")?;
118                } else if let Sym(Rpow) = **a {
119                    r("^")?;
120                } else {
121                    if let (EOp(Apply, f, a), Sym(Pi)) = (&**a, &**b) {
122                        if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
123                            write!(w, "{}π", a)?;
124                            return Ok(())
125                        }
126                    }
127                    if let (EOp(Apply, f, a), Sym(Tau)) = (&**a, &**b) {
128                        if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
129                            write!(w, "{}τ", a)?;
130                            return Ok(())
131                        }
132                    }
133                    if let (EOp(Apply, f, a), Sym(Eps)) = (&**a, &**b) {
134                        if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
135                            write!(w, "{}ε", a)?;
136                            return Ok(())
137                        }
138                    }
139                    if let (EOp(Apply, f, a), Sym(Imag)) = (&**a, &**b) {
140                        if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
141                            write!(w, "{}𝐢", a)?;
142                            return Ok(())
143                        }
144                    }
145                    if let (EOp(Apply, f, a), Sym(Imag2)) = (&**a, &**b) {
146                        if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
147                            write!(w, "{}𝐢₂", a)?;
148                            return Ok(())
149                        }
150                    }
151                    if let (EOp(Apply, f, a), Sym(Imag3)) = (&**a, &**b) {
152                        if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
153                            write!(w, "{}𝐢₃", a)?;
154                            return Ok(())
155                        }
156                    }
157                    if let (EOp(Apply, f, b), Sym(Var(ref a))) = (&**a, &**b) {
158                        if let (Sym(Mul), Sym(Pariv)) = (&**f, &**b) {
159                            write!(w, "∂{}", a)?;
160                            return Ok(())
161                        }
162                    }
163                    if let EOp(Apply, f, a) = &**a {
164                        let mut pr = |
165                            op_txt: &str,
166                            op_sym: &Symbol
167                        | -> std::result::Result<(), fmt::Error> {
168                            if parens {write!(w, "(")?};
169                            let left = true;
170                            a.display(w, a.needs_parens(op_sym, left), rule)?;
171                            write!(w, " {} ", op_txt)?;
172                            let right = false;
173                            b.display(w, b.needs_parens(op_sym, right), rule)?;
174                            if parens {write!(w, ")")?};
175                            Ok(())
176                        };
177
178                        match **f {
179                            Sym(Add) => {
180                                pr("+", &Add)?;
181                                return Ok(())
182                            }
183                            Sym(Sub) => {
184                                pr("-", &Sub)?;
185                                return Ok(())
186                            }
187                            Sym(Mul) => {
188                                pr("*", &Mul)?;
189                                return Ok(())
190                            }
191                            Sym(Div) => {
192                                pr("/", &Div)?;
193                                return Ok(())
194                            }
195                            Sym(Rem) => {
196                                pr("%", &Rem)?;
197                                return Ok(())
198                            }
199                            Sym(Pow) => {
200                                pr("^", &Pow)?;
201                                return Ok(())
202                            }
203                            Sym(And) => {
204                                pr("&", &And)?;
205                                return Ok(())
206                            }
207                            Sym(Or) => {
208                                pr("|", &Or)?;
209                                return Ok(())
210                            }
211                            Sym(Concat) => {
212                                pr("++", &Concat)?;
213                                return Ok(())
214                            }
215                            Sym(Lt) => {
216                                pr("<", &Lt)?;
217                                return Ok(())
218                            }
219                            Sym(Le) => {
220                                pr("<=", &Le)?;
221                                return Ok(())
222                            }
223                            Sym(Eq) => {
224                                pr("=", &Eq)?;
225                                return Ok(())
226                            }
227                            Sym(Gt) => {
228                                pr(">", &Gt)?;
229                                return Ok(())
230                            }
231                            Sym(Ge) => {
232                                pr(">=", &Ge)?;
233                                return Ok(())
234                            }
235                            _ => {}
236                        }
237                    }
238
239                    if let Ret(_) = **a {
240                        write!(w, "\\")?;
241                    }
242                    let parens = true;
243                    a.display(w, parens, rule)?;
244                    if let Tup(_) = &**b {
245                        b.display(w, parens, rule)?;
246                    } else {
247                        write!(w, "(")?;
248                        b.display(w, false, rule)?;
249                        write!(w, ")")?;
250                    }
251                }
252            }
253            EOp(Constrain, a, b) => {
254                if let Ret(_) = **a {
255                    write!(w, "\\")?;
256                }
257                a.display(w, true, rule)?;
258                if let Tup(b) = &**b {
259                    write!(w, "{{")?;
260                    for i in 0..b.len() {
261                        if i > 0 {write!(w, ", ")?}
262                        b[i].display(w, false, rule)?;
263                    }
264                    write!(w, "}}")?;
265                } else {
266                    write!(w, "{{")?;
267                    b.display(w, false, rule)?;
268                    write!(w, "}}")?;
269                }
270            }
271            EOp(Compose, a, b) => {
272                if parens {
273                    write!(w, "(")?;
274                }
275                a.display(w, true, rule)?;
276                write!(w, " · ")?;
277                b.display(w, true, rule)?;
278                if parens {
279                    write!(w, ")")?;
280                }
281            }
282            EOp(Type, a, b) => {
283                if parens {
284                    write!(w, "(")?;
285                }
286                a.display(w, true, rule)?;
287                write!(w, " : ")?;
288                b.display(w, true, rule)?;
289                if parens {
290                    write!(w, ")")?;
291                }
292            }
293            Tup(b) => {
294                write!(w, "(")?;
295                for i in 0..b.len() {
296                    if i > 0 {write!(w, ", ")?}
297                    b[i].display(w, false, rule)?;
298                }
299                write!(w, ")")?;
300            }
301            List(b) => {
302                write!(w, "[")?;
303                for i in 0..b.len() {
304                    if i > 0 {write!(w, ", ")?}
305                    b[i].display(w, false, rule)?;
306                }
307                write!(w, "]")?;
308            }
309            // _ => write!(w, "{:?}", self)?,
310        }
311        Ok(())
312    }
313
314    /// Returns `true` if the expression needs parentheses, given parent operation and side.
315    pub fn needs_parens(&self, parent_op: &Symbol, left: bool) -> bool {
316        if let EOp(Apply, f, _) = self {
317            if let EOp(Apply, f, _) = &**f {
318                match &**f {
319                    Sym(x) => {
320                        if let (Some(x), Some(y)) = (x.precedence(), parent_op.precedence()) {
321                            if left {x > y} else {x >= y}
322                        } else {true}
323                    }
324                    _ => true
325                }
326            } else {
327                match &**f {
328                    Sym(x) => {
329                        if let (Some(x), Some(y)) = (x.precedence(), parent_op.precedence()) {
330                            if left {x > y} else {x >= y}
331                        } else {true}
332                    }
333                    _ => true
334                }
335            }
336        } else {
337            true
338        }
339    }
340}
341
342impl fmt::Display for Expr {
343    fn fmt(&self, w: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
344        let parens = false;
345        let rule = false;
346        self.display(w, parens, rule)
347    }
348}
349
#[cfg(test)]
mod tests {
    use crate::*;
    use std::fmt;

    /// Asserts that `value` renders through `Display` as exactly `expected`.
    fn check(value: impl fmt::Display, expected: &str) {
        assert_eq!(value.to_string(), expected);
    }

    /// Renders the wrapped expression in rule mode, without outer parentheses.
    struct Rule(Expr);

    impl fmt::Display for Rule {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            self.0.display(f, false, true)
        }
    }

    #[test]
    fn parens() {
        // Arithmetic: equal-precedence chains associate to the left.
        check(app2(Mul, app2(Mul, "a", "b"), "c"), "a * b * c");
        check(app2(Mul, "a", app2(Mul, "b", "c")), "a * (b * c)");
        check(app2(Add, "a", "b"), "a + b");
        check(app2(Mul, app2(Add, "a", "b"), "c"), "(a + b) * c");
        check(app2(Add, app2(Add, "a", "b"), "c"), "a + b + c");
        check(app2(Add, "a", app2(Add, "b", "c")), "a + (b + c)");
        check(app2(Pow, "a", 2.0), "a ^ 2");
        check(app2(Add, "a", app2(Pow, "b", 2.0)), "a + b ^ 2");
        check(app2(Add, app2(Pow, "a", 2.0), "b"), "a ^ 2 + b");
        check(app2(Div, app2(Add, "a", "b"), "c"), "(a + b) / c");
        check(app2(Sub, "a", "b"), "a - b");
        check(app2(Sub, app2(Sub, "a", "b"), "c"), "a - b - c");
        check(app2(Add, app2(Sub, "a", "b"), "c"), "a - b + c");
        check(app2(Sub, app2(Add, "a", "b"), "c"), "a + b - c");
        check(app2(Mul, app2(Sub, "a", "b"), "c"), "(a - b) * c");
        check(app2(Sub, app2(Mul, "a", "b"), "c"), "a * b - c");
        check(app2(Sub, "a", app2(Mul, "b", "c")), "a - b * c");
        check(app2(Div, app2(Sub, "a", "b"), "c"), "(a - b) / c");
        check(app2(Sub, app2(Div, "a", "b"), "c"), "a / b - c");
        check(app2(Sub, "a", app2(Div, "b", "c")), "a - b / c");
        check(app2(Div, "a", "b"), "a / b");
        check(app2(Div, app2(Div, "a", "b"), "c"), "a / b / c");
        check(app2(Eq, app2(Add, "a", "b"), "c"), "a + b = c");

        // Logic.
        check(app2(Or, "a", "b"), "a | b");
        check(app2(And, "a", "b"), "a & b");
        check(app2(Or, app2(And, "a", "b"), "c"), "a & b | c");
        check(app2(And, app2(Or, "a", "b"), "c"), "(a | b) & c");

        // Composition, constraints and type judgements.
        check(comp("f", "g"), "f · g");
        check(constr("f", "x"), "f{x}");
        check(constr(comp("f", "g"), "x"), "(f · g){x}");
        check(comp("f", comp("g", "h")), "f · (g · h)");
        check(comp(comp("f", "g"), "h"), "(f · g) · h");
        check(typ("a", "b"), "a : b");
        check(typ(typ("a", "b"), "c"), "(a : b) : c");
        check(typ("a", typ("b", "c")), "a : (b : c)");

        // Prefix operators.
        check(app(Neg, app(Neg, "a")), "-(-a)");
        check(app(Not, app2(Or, "a", "b")), "!(a | b)");
        check(app2(Or, app(Not, "a"), "b"), "!a | b");
        check(app2(Or, "a", app(Not, "b")), "a | !b");
    }

    #[test]
    fn constraints() {
        check(Rule(arity_var("f", 1)), "f:[arity]1");
        check(Rule(comp("f", arity_var("g", 1))), "f · g:[arity]1");
        check(Rule(constr(comp("f", arity_var("g", 1)), "x")), "(f · g:[arity]1){x}");
        check(Rule(app(comp("f", arity_var("g", 1)), "a")), "(f · g:[arity]1)(a)");
        check(Rule(path("f", arity_var("g", 1))), "f[g:[arity]1]");
        check(Rule(app(Neg, arity_var("f", 1))), "-f:[arity]1");
        check(Rule(app(Not, arity_var("f", 1))), "!f:[arity]1");
        check(Rule(app(Rty, arity_var("f", 1))), "(: f:[arity]1)");
        check(Rule(app(Rlt, arity_var("f", 1))), "(< f:[arity]1)");
        check(Rule(app(Triv, constr(arity_var("f", 1), "g"))), "∀(f:[arity]1{g})");
    }
}