1use std::fmt;
2use std::ops::{Add, Sub, Mul};
3
4use super::*;
5
6#[derive(Clone, PartialEq, PartialOrd, Debug)]
8pub enum Expr {
9 Sym(Symbol),
11 Ret(Value),
15 EOp(Op, Box<Expr>, Box<Expr>),
17 Tup(Vec<Expr>),
19 List(Vec<Expr>),
21}
22
23impl Add for Expr {
24 type Output = Expr;
25 fn add(self, other: Expr) -> Expr {app2(Add, self, other)}
26}
27
28impl Sub for Expr {
29 type Output = Expr;
30 fn sub(self, other: Expr) -> Expr {app2(Sub, self, other)}
31}
32
33impl Mul for Expr {
34 type Output = Expr;
35 fn mul(self, other: Expr) -> Expr {app2(Mul, self, other)}
36}
37
38impl Expr {
39 pub fn display(
41 &self,
42 w: &mut fmt::Formatter<'_>,
43 parens: bool,
44 rule: bool,
45 ) -> std::result::Result<(), fmt::Error> {
46 match self {
47 Sym(s) => s.display(w, rule)?,
48 Ret(v) => write!(w, "{}", v)?,
49 EOp(Path, a, b) => {
50 if let Tup(b) = &**b {
51 let parens = true;
52 a.display(w, parens, rule)?;
53 write!(w, "[")?;
54 for i in 0..b.len() {
55 if i > 0 {
56 if i + 1 < b.len() {
57 write!(w, " ⨯ ")?
58 } else {
59 write!(w, " → ")?
60 }
61 }
62 b[i].display(w, true, rule)?;
63 }
64 write!(w, "]")?
65 } else {
66 a.display(w, true, rule)?;
67 write!(w, "[")?;
68 b.display(w, false, rule)?;
69 write!(w, "]")?;
70 }
71 }
72 EOp(Apply, a, b) => {
73 let mut r = |op: &str| -> std::result::Result<(), fmt::Error> {
74 write!(w, "({} ", op)?;
75 b.display(w, false, rule)?;
76 write!(w, ")")
77 };
78 if let Sym(Neg) = **a {
79 if parens {
80 write!(w, "(")?;
81 }
82 write!(w, "-")?;
83 b.display(w, true, rule)?;
84 if parens {
85 write!(w, ")")?;
86 }
87 } else if let Sym(Not) = **a {
88 if parens {
89 write!(w, "(")?;
90 }
91 write!(w, "!")?;
92 b.display(w, true, rule)?;
93 if parens {
94 write!(w, ")")?;
95 }
96 } else if let Sym(Rty) = **a {
97 if let Sym(_) = **b {
98 r(":")?;
99 }
100 } else if let Sym(Rlt) = **a {
101 r("<")?;
102 } else if let Sym(Rle) = **a {
103 r("<=")?;
104 } else if let Sym(Eq) = **a {
105 r("=")?;
106 } else if let Sym(Rgt) = **a {
107 r(">")?;
108 } else if let Sym(Rge) = **a {
109 r(">=")?;
110 } else if let Sym(Mul) = **a {
111 r("*")?;
112 } else if let Sym(Add) = **a {
113 r("+")?;
114 } else if let Sym(Rsub) = **a {
115 r("-")?;
116 } else if let Sym(Rdiv) = **a {
117 r("/")?;
118 } else if let Sym(Rpow) = **a {
119 r("^")?;
120 } else {
121 if let (EOp(Apply, f, a), Sym(Pi)) = (&**a, &**b) {
122 if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
123 write!(w, "{}π", a)?;
124 return Ok(())
125 }
126 }
127 if let (EOp(Apply, f, a), Sym(Tau)) = (&**a, &**b) {
128 if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
129 write!(w, "{}τ", a)?;
130 return Ok(())
131 }
132 }
133 if let (EOp(Apply, f, a), Sym(Eps)) = (&**a, &**b) {
134 if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
135 write!(w, "{}ε", a)?;
136 return Ok(())
137 }
138 }
139 if let (EOp(Apply, f, a), Sym(Imag)) = (&**a, &**b) {
140 if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
141 write!(w, "{}𝐢", a)?;
142 return Ok(())
143 }
144 }
145 if let (EOp(Apply, f, a), Sym(Imag2)) = (&**a, &**b) {
146 if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
147 write!(w, "{}𝐢₂", a)?;
148 return Ok(())
149 }
150 }
151 if let (EOp(Apply, f, a), Sym(Imag3)) = (&**a, &**b) {
152 if let (Sym(Mul), Ret(F64(a))) = (&**f, &**a) {
153 write!(w, "{}𝐢₃", a)?;
154 return Ok(())
155 }
156 }
157 if let (EOp(Apply, f, b), Sym(Var(ref a))) = (&**a, &**b) {
158 if let (Sym(Mul), Sym(Pariv)) = (&**f, &**b) {
159 write!(w, "∂{}", a)?;
160 return Ok(())
161 }
162 }
163 if let EOp(Apply, f, a) = &**a {
164 let mut pr = |
165 op_txt: &str,
166 op_sym: &Symbol
167 | -> std::result::Result<(), fmt::Error> {
168 if parens {write!(w, "(")?};
169 let left = true;
170 a.display(w, a.needs_parens(op_sym, left), rule)?;
171 write!(w, " {} ", op_txt)?;
172 let right = false;
173 b.display(w, b.needs_parens(op_sym, right), rule)?;
174 if parens {write!(w, ")")?};
175 Ok(())
176 };
177
178 match **f {
179 Sym(Add) => {
180 pr("+", &Add)?;
181 return Ok(())
182 }
183 Sym(Sub) => {
184 pr("-", &Sub)?;
185 return Ok(())
186 }
187 Sym(Mul) => {
188 pr("*", &Mul)?;
189 return Ok(())
190 }
191 Sym(Div) => {
192 pr("/", &Div)?;
193 return Ok(())
194 }
195 Sym(Rem) => {
196 pr("%", &Rem)?;
197 return Ok(())
198 }
199 Sym(Pow) => {
200 pr("^", &Pow)?;
201 return Ok(())
202 }
203 Sym(And) => {
204 pr("&", &And)?;
205 return Ok(())
206 }
207 Sym(Or) => {
208 pr("|", &Or)?;
209 return Ok(())
210 }
211 Sym(Concat) => {
212 pr("++", &Concat)?;
213 return Ok(())
214 }
215 Sym(Lt) => {
216 pr("<", &Lt)?;
217 return Ok(())
218 }
219 Sym(Le) => {
220 pr("<=", &Le)?;
221 return Ok(())
222 }
223 Sym(Eq) => {
224 pr("=", &Eq)?;
225 return Ok(())
226 }
227 Sym(Gt) => {
228 pr(">", &Gt)?;
229 return Ok(())
230 }
231 Sym(Ge) => {
232 pr(">=", &Ge)?;
233 return Ok(())
234 }
235 _ => {}
236 }
237 }
238
239 if let Ret(_) = **a {
240 write!(w, "\\")?;
241 }
242 let parens = true;
243 a.display(w, parens, rule)?;
244 if let Tup(_) = &**b {
245 b.display(w, parens, rule)?;
246 } else {
247 write!(w, "(")?;
248 b.display(w, false, rule)?;
249 write!(w, ")")?;
250 }
251 }
252 }
253 EOp(Constrain, a, b) => {
254 if let Ret(_) = **a {
255 write!(w, "\\")?;
256 }
257 a.display(w, true, rule)?;
258 if let Tup(b) = &**b {
259 write!(w, "{{")?;
260 for i in 0..b.len() {
261 if i > 0 {write!(w, ", ")?}
262 b[i].display(w, false, rule)?;
263 }
264 write!(w, "}}")?;
265 } else {
266 write!(w, "{{")?;
267 b.display(w, false, rule)?;
268 write!(w, "}}")?;
269 }
270 }
271 EOp(Compose, a, b) => {
272 if parens {
273 write!(w, "(")?;
274 }
275 a.display(w, true, rule)?;
276 write!(w, " · ")?;
277 b.display(w, true, rule)?;
278 if parens {
279 write!(w, ")")?;
280 }
281 }
282 EOp(Type, a, b) => {
283 if parens {
284 write!(w, "(")?;
285 }
286 a.display(w, true, rule)?;
287 write!(w, " : ")?;
288 b.display(w, true, rule)?;
289 if parens {
290 write!(w, ")")?;
291 }
292 }
293 Tup(b) => {
294 write!(w, "(")?;
295 for i in 0..b.len() {
296 if i > 0 {write!(w, ", ")?}
297 b[i].display(w, false, rule)?;
298 }
299 write!(w, ")")?;
300 }
301 List(b) => {
302 write!(w, "[")?;
303 for i in 0..b.len() {
304 if i > 0 {write!(w, ", ")?}
305 b[i].display(w, false, rule)?;
306 }
307 write!(w, "]")?;
308 }
309 }
311 Ok(())
312 }
313
314 pub fn needs_parens(&self, parent_op: &Symbol, left: bool) -> bool {
316 if let EOp(Apply, f, _) = self {
317 if let EOp(Apply, f, _) = &**f {
318 match &**f {
319 Sym(x) => {
320 if let (Some(x), Some(y)) = (x.precedence(), parent_op.precedence()) {
321 if left {x > y} else {x >= y}
322 } else {true}
323 }
324 _ => true
325 }
326 } else {
327 match &**f {
328 Sym(x) => {
329 if let (Some(x), Some(y)) = (x.precedence(), parent_op.precedence()) {
330 if left {x > y} else {x >= y}
331 } else {true}
332 }
333 _ => true
334 }
335 }
336 } else {
337 true
338 }
339 }
340}
341
342impl fmt::Display for Expr {
343 fn fmt(&self, w: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
344 let parens = false;
345 let rule = false;
346 self.display(w, parens, rule)
347 }
348}
349
#[cfg(test)]
mod tests {
    use crate::*;
    use std::fmt;

    /// Asserts that `value` formats (via `Display`) exactly as `expected`.
    fn assert_renders(value: impl fmt::Display, expected: &str) {
        assert_eq!(format!("{}", value), expected);
    }

    #[test]
    fn parens() {
        // Multiplication and addition group to the left.
        assert_renders(app2(Mul, app2(Mul, "a", "b"), "c"), "a * b * c");
        assert_renders(app2(Mul, "a", app2(Mul, "b", "c")), "a * (b * c)");
        assert_renders(app2(Add, "a", "b"), "a + b");
        assert_renders(app2(Mul, app2(Add, "a", "b"), "c"), "(a + b) * c");
        assert_renders(app2(Add, app2(Add, "a", "b"), "c"), "a + b + c");
        assert_renders(app2(Add, "a", app2(Add, "b", "c")), "a + (b + c)");

        // Powers bind tighter than addition.
        assert_renders(app2(Pow, "a", 2.0), "a ^ 2");
        assert_renders(app2(Add, "a", app2(Pow, "b", 2.0)), "a + b ^ 2");
        assert_renders(app2(Add, app2(Pow, "a", 2.0), "b"), "a ^ 2 + b");

        // Subtraction and division.
        assert_renders(app2(Div, app2(Add, "a", "b"), "c"), "(a + b) / c");
        assert_renders(app2(Sub, "a", "b"), "a - b");
        assert_renders(app2(Sub, app2(Sub, "a", "b"), "c"), "a - b - c");
        assert_renders(app2(Add, app2(Sub, "a", "b"), "c"), "a - b + c");
        assert_renders(app2(Sub, app2(Add, "a", "b"), "c"), "a + b - c");
        assert_renders(app2(Mul, app2(Sub, "a", "b"), "c"), "(a - b) * c");
        assert_renders(app2(Sub, app2(Mul, "a", "b"), "c"), "a * b - c");
        assert_renders(app2(Sub, "a", app2(Mul, "b", "c")), "a - b * c");
        assert_renders(app2(Div, app2(Sub, "a", "b"), "c"), "(a - b) / c");
        assert_renders(app2(Sub, app2(Div, "a", "b"), "c"), "a / b - c");
        assert_renders(app2(Sub, "a", app2(Div, "b", "c")), "a - b / c");
        assert_renders(app2(Div, "a", "b"), "a / b");
        assert_renders(app2(Div, app2(Div, "a", "b"), "c"), "a / b / c");

        // Comparison and logic.
        assert_renders(app2(Eq, app2(Add, "a", "b"), "c"), "a + b = c");
        assert_renders(app2(Or, "a", "b"), "a | b");
        assert_renders(app2(And, "a", "b"), "a & b");
        assert_renders(app2(Or, app2(And, "a", "b"), "c"), "a & b | c");
        assert_renders(app2(And, app2(Or, "a", "b"), "c"), "(a | b) & c");

        // Composition, constraints and type ascription.
        assert_renders(comp("f", "g"), "f · g");
        assert_renders(constr("f", "x"), "f{x}");
        assert_renders(constr(comp("f", "g"), "x"), "(f · g){x}");
        assert_renders(comp("f", comp("g", "h")), "f · (g · h)");
        assert_renders(comp(comp("f", "g"), "h"), "(f · g) · h");
        assert_renders(typ("a", "b"), "a : b");
        assert_renders(typ(typ("a", "b"), "c"), "(a : b) : c");
        assert_renders(typ("a", typ("b", "c")), "a : (b : c)");

        // Prefix operators.
        assert_renders(app(Neg, app(Neg, "a")), "-(-a)");
        assert_renders(app(Not, app2(Or, "a", "b")), "!(a | b)");
        assert_renders(app2(Or, app(Not, "a"), "b"), "!a | b");
        assert_renders(app2(Or, "a", app(Not, "b")), "a | !b");
    }

    /// Formats the wrapped expression at top level with `rule = true`.
    struct Rule(Expr);

    impl fmt::Display for Rule {
        fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
            // (parens = false, rule = true)
            self.0.display(w, false, true)
        }
    }

    #[test]
    fn constraints() {
        assert_renders(Rule(arity_var("f", 1)), "f:[arity]1");
        assert_renders(Rule(comp("f", arity_var("g", 1))), "f · g:[arity]1");
        assert_renders(
            Rule(constr(comp("f", arity_var("g", 1)), "x")),
            "(f · g:[arity]1){x}",
        );
        assert_renders(
            Rule(app(comp("f", arity_var("g", 1)), "a")),
            "(f · g:[arity]1)(a)",
        );
        assert_renders(Rule(path("f", arity_var("g", 1))), "f[g:[arity]1]");
        assert_renders(Rule(app(Neg, arity_var("f", 1))), "-f:[arity]1");
        assert_renders(Rule(app(Not, arity_var("f", 1))), "!f:[arity]1");
        assert_renders(Rule(app(Rty, arity_var("f", 1))), "(: f:[arity]1)");
        assert_renders(Rule(app(Rlt, arity_var("f", 1))), "(< f:[arity]1)");
        assert_renders(
            Rule(app(Triv, constr(arity_var("f", 1), "g"))),
            "∀(f:[arity]1{g})",
        );
    }
}