use std::io;
use std::ops::{Add, Div, Mul, Neg, Sub};
use elegance::{
core::Printer,
render::{Io, Render},
};
/// Operator binding strength, ordered from loosest (`Top`) to tightest (`Atom`).
///
/// The derived ordering follows declaration order; the printer compares these
/// values to decide whether a sub-expression must be parenthesized.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
enum Prec {
    /// Outermost context; no expression has this precedence itself.
    Top,
    /// `+` and `-`.
    AddSub,
    /// `*` and `/`.
    MulDiv,
    /// Prefix negation.
    Unary,
    /// Integer literals and variables.
    Atom,
}
/// Grouping direction of a binary operator.
///
/// With `Left`, the left operand may share the operator's precedence without
/// parentheses while the right operand may not (so `a - (b - c)` keeps its
/// parentheses); `Right` is the mirror image.
#[derive(Copy, Clone, Eq, PartialEq)]
enum Assoc {
    Left,
    Right,
}
/// Arithmetic expression tree rendered by `Expr::print`.
enum Expr {
    /// Integer literal.
    Int(i64),
    /// Named variable.
    Var(&'static str),
    /// Prefix negation `-e`.
    Neg(Box<Self>),
    /// `lhs + rhs`.
    Add(Box<Self>, Box<Self>),
    /// `lhs - rhs`.
    Sub(Box<Self>, Box<Self>),
    /// `lhs * rhs`.
    Mul(Box<Self>, Box<Self>),
    /// `lhs / rhs`.
    Div(Box<Self>, Box<Self>),
}
impl Expr {
fn prec(&self) -> Prec {
match self {
Expr::Int(_) | Expr::Var(_) => Prec::Atom,
Expr::Neg(_) => Prec::Unary,
Expr::Mul(_, _) | Expr::Div(_, _) => Prec::MulDiv,
Expr::Add(_, _) | Expr::Sub(_, _) => Prec::AddSub,
}
}
pub fn print<R: Render>(
&self,
pp: &mut Printer<R, String, (Prec, bool)>,
) -> Result<(), R::Error> {
let (ctx_prec, needs_assoc_parens) = pp.extra;
let my_prec = self.prec();
let needs_parens = my_prec < ctx_prec
|| (my_prec == ctx_prec && my_prec != Prec::Atom && needs_assoc_parens);
pp.cgroup(0, |pp| {
if needs_parens {
pp.text("(")?;
pp.scan_break(0, 2)?;
}
match self {
Expr::Int(n) => pp.text_owned(n.to_string())?,
Expr::Var(name) => pp.text(*name)?,
Expr::Neg(e) => {
pp.text("-")?;
pp.extra = (Prec::Unary, false);
e.print(pp)?;
}
Expr::Add(lhs, rhs) => {
self.print_binop(pp, lhs, "+", rhs, Prec::AddSub, Assoc::Left)?
}
Expr::Sub(lhs, rhs) => {
self.print_binop(pp, lhs, "-", rhs, Prec::AddSub, Assoc::Left)?
}
Expr::Mul(lhs, rhs) => {
self.print_binop(pp, lhs, "*", rhs, Prec::MulDiv, Assoc::Left)?
}
Expr::Div(lhs, rhs) => {
self.print_binop(pp, lhs, "/", rhs, Prec::MulDiv, Assoc::Left)?
}
}
if needs_parens {
pp.scan_break(0, 0)?;
pp.text(")")?;
}
Ok(())
})?;
pp.extra = (ctx_prec, needs_assoc_parens);
Ok(())
}
fn print_binop<R: Render>(
&self,
pp: &mut Printer<R, String, (Prec, bool)>,
lhs: &Expr,
op: &'static str,
rhs: &Expr,
prec: Prec,
assoc: Assoc,
) -> Result<(), R::Error> {
pp.igroup(0, |pp| {
pp.extra = (prec, assoc == Assoc::Right);
lhs.print(pp)?;
pp.space()?;
pp.text(op)?;
pp.scan_break(1, 2)?;
pp.extra = (prec, assoc == Assoc::Left);
rhs.print(pp)
})
}
}
/// Shorthand for a boxed expression, used when building the samples in `main`.
type E = Box<Expr>;

/// Builds a boxed reference to the variable `name`.
fn v(name: &'static str) -> E {
    E::new(Expr::Var(name))
}
/// Turns an integer into a boxed `Expr::Int` leaf (e.g. `E::from(1)` in `main`).
impl From<i64> for Box<Expr> {
    fn from(value: i64) -> Self {
        Self::new(Expr::Int(value))
    }
}
// Implements a binary operator trait for `Box<Expr>` so that `lhs OP rhs`
// builds the matching `Expr` variant from the two boxed operands.
// Each entry reads `Trait => method => Variant`.
macro_rules! impl_op {
    ($($tr:path => $method:ident => $variant:ident),* $(,)?) => {
        $(
            impl $tr for Box<Expr> {
                type Output = Box<Expr>;

                fn $method(self, rhs: Box<Expr>) -> Box<Expr> {
                    Box::new(Expr::$variant(self, rhs))
                }
            }
        )*
    };
}

impl_op! {
    Add => add => Add,
    Sub => sub => Sub,
    Mul => mul => Mul,
    Div => div => Div,
}
/// Prefix `-` on a boxed expression wraps it in `Expr::Neg`.
impl Neg for Box<Expr> {
    type Output = Box<Expr>;

    fn neg(self) -> Self::Output {
        Self::new(Expr::Neg(self))
    }
}
/// Pretty-prints a fixed set of sample expressions to stdout at width 40.
///
/// Each expression is preceded by a `"<description>:"` line and followed by two
/// newlines.
///
/// # Errors
/// Returns any I/O error raised while writing to stdout.
fn main() -> io::Result<()> {
    let examples: Vec<(&str, E)> = vec![
        ("1 + 2 * 3", E::from(1) + E::from(2) * E::from(3)),
        ("(1 + 2) * 3", (E::from(1) + E::from(2)) * E::from(3)),
        ("a - b - c (left assoc)", (v("a") - v("b")) - v("c")),
        ("a - (b - c)", v("a") - (v("b") - v("c"))),
        ("a / b * c", v("a") / v("b") * v("c")),
        ("a / (b * c)", v("a") / (v("b") * v("c"))),
        (
            "(a + b) * (c - d) / (e + f)",
            (v("a") + v("b")) * (v("c") - v("d")) / (v("e") + v("f")),
        ),
        ("-a * b", -v("a") * v("b")),
        ("-(a * b)", -(v("a") * v("b"))),
        (
            "long expr",
            v("alpha") * v("beta")
                + v("gamma") * v("delta")
                + v("epsilon") * v("zeta")
                + v("eta") / v("theta"),
        ),
        (
            "wrap with parens",
            v("foo") * (v("alpha") + v("beta") + v("gamma") + v("delta"))
                / (v("epsilon") - v("zeta")),
        ),
    ];

    for (desc, expr) in &examples {
        println!("{}:", desc);
        // A fresh printer per example, starting from the outermost context.
        let mut pp = Printer::new_extra(Io(io::stdout()), 40, (Prec::Top, false));
        expr.print(&mut pp)?;
        pp.finish()?;
        println!();
        println!();
    }
    Ok(())
}