use crate::grammar::*;
use crate::utils::{unflatten_binary_expr, UnflattenStrategy};
use std::collections::{BTreeMap, VecDeque};
use std::rc::Rc;
/// Rewrites `expr` into a flattened, partially simplified form.
///
/// Grouping nodes (`Parend`, `Bracketed`) are dropped. `+`/`-` nodes are handed to
/// `flatten_add_or_sub` and `*`/`/` nodes to `flatten_mul_or_div`. Any other binary or
/// unary operator keeps its node and only has its operand(s) flattened. Constants and
/// variables are returned by sharing the existing `Rc` (no deep copy).
pub fn flatten_expr(expr: &Rc<Expr>) -> Rc<Expr> {
    match expr.as_ref() {
        Expr::Const(_) | Expr::Var(_) => Rc::clone(expr),
        Expr::Parend(inner) | Expr::Bracketed(inner) => flatten_expr(inner),
        Expr::BinaryExpr(BinaryExpr { op, lhs, rhs }) => match op {
            BinaryOperator::Plus => flatten_add_or_sub(lhs, rhs, false),
            BinaryOperator::Minus => flatten_add_or_sub(lhs, rhs, true),
            BinaryOperator::Mult => flatten_mul_or_div(lhs, rhs, false),
            BinaryOperator::Div => flatten_mul_or_div(lhs, rhs, true),
            // Operators without a dedicated simplifier: rebuild the node around
            // flattened operands.
            _ => Rc::new(Expr::BinaryExpr(BinaryExpr {
                op: *op,
                lhs: flatten_expr(lhs),
                rhs: flatten_expr(rhs),
            })),
        },
        Expr::UnaryExpr(UnaryExpr { op, rhs }) => Rc::new(Expr::UnaryExpr(UnaryExpr {
            op: *op,
            rhs: flatten_expr(rhs),
        })),
    }
}
fn flatten_add_or_sub(o_lhs: &Rc<Expr>, o_rhs: &Rc<Expr>, is_subtract: bool) -> Rc<Expr> {
let lhs = flatten_expr(o_lhs);
let rhs = flatten_expr(o_rhs);
let mut coeff = 0.;
let mut terms = BTreeMap::<&Rc<Expr>, f64>::new();
let mut args = VecDeque::with_capacity(2);
let base_args = [lhs, rhs];
args.extend(base_args.iter());
let mut args_before_neg = 1;
while let Some(arg) = args.pop_front() {
let is_neg = is_subtract && args_before_neg == 0;
args_before_neg -= 1;
match arg.as_ref() {
Expr::Const(konst) => {
if is_neg {
coeff -= konst;
} else {
coeff += konst;
}
}
Expr::BinaryExpr(BinaryExpr { op, lhs, rhs }) if op == &BinaryOperator::Plus => {
if is_neg {
args.push_back(lhs);
args.push_back(rhs);
} else {
args.push_front(lhs);
args.push_front(rhs);
args_before_neg += 2;
}
}
_ => {
let entry = terms.entry(arg).or_insert(0.);
if is_neg {
*entry -= 1.;
} else {
*entry += 1.;
}
}
}
}
let mut new_args: Vec<Rc<Expr>> = Vec::with_capacity(1 + terms.len());
if coeff != 0. {
new_args.push(Rc::from(Expr::Const(coeff)));
}
for (term, coeff) in terms {
if coeff == 0. {
continue;
} else if (coeff - 1.).abs() < std::f64::EPSILON {
new_args.push(Rc::clone(term));
} else if (coeff - -1.).abs() < std::f64::EPSILON {
let neg = UnaryExpr::negate(Rc::clone(term));
new_args.push(Rc::from(Expr::UnaryExpr(neg)));
} else {
let mult = BinaryExpr::mult(Expr::Const(coeff), Rc::clone(term));
let expr: Expr = mult.into();
new_args.push(Rc::from(expr));
}
}
match new_args.len() {
0 => Rc::from(Expr::Const(0.)),
1 => new_args.remove(0),
_ => unflatten_binary_expr(&new_args, BinaryOperator::Plus, UnflattenStrategy::Left),
}
}
/// Flattens `o_lhs * o_rhs` (or `o_lhs / o_rhs` when `is_div` is set) into a
/// canonical product.
///
/// Both operands are flattened first. Then every nested `*` / `/` node is walked:
/// - Numeric constants fold into a single coefficient.
/// - Every other factor is collected with an exponent: +1 for each occurrence in the
///   numerator, -1 for each in the denominator. So `x * x` becomes `x ^ 2` and
///   `x / x` cancels.
///
/// Output factors:
/// - the coefficient, omitted when it is within `f64::EPSILON` of 1;
/// - each factor with exponent 1 as-is;
/// - each factor with exponent -1 as `1 / factor`;
/// - any other non-zero exponent as `factor ^ exponent`.
///
/// Returns the constant `1` when nothing remains, and the single factor when only one
/// remains. Otherwise the factors are combined with `unflatten_binary_expr`
/// (`BinaryOperator::Mult`, `UnflattenStrategy::Left`).
///
/// The traversal order below fixes the order in which constants are multiplied or
/// divided into the coefficient, which can affect float rounding of the result.
fn flatten_mul_or_div(o_lhs: &Rc<Expr>, o_rhs: &Rc<Expr>, is_div: bool) -> Rc<Expr> {
    let lhs = flatten_expr(o_lhs);
    let rhs = flatten_expr(o_rhs);
    let mut coeff = 1.;
    let mut terms = BTreeMap::<&Rc<Expr>, f64>::new();
    // Work queue of factors still to be classified. Numerator factors are kept at the
    // front and denominator factors at the back.
    let mut args = VecDeque::with_capacity(2);
    let base_args = [lhs, rhs];
    args.extend(base_args.iter());
    // Number of numerator factors still waiting at the front of `args`. Once it
    // reaches zero, everything left in the queue belongs to the denominator.
    let mut args_before_div = if is_div { 1 } else { 2 };
    while let Some(arg) = args.pop_front() {
        let div_side = args_before_div <= 0;
        args_before_div -= 1;
        let arg = unwrap_expr(arg);
        match arg.as_ref() {
            Expr::Const(konst) => {
                // No zero check: dividing by a constant 0 leaves an infinite or NaN
                // coefficient, per f64 semantics.
                if div_side {
                    coeff /= konst;
                } else {
                    coeff *= konst;
                }
            }
            Expr::BinaryExpr(BinaryExpr { op, lhs, rhs })
                if op == &BinaryOperator::Mult || op == &BinaryOperator::Div =>
            {
                if div_side {
                    if op == &BinaryOperator::Mult {
                        // 1 / (l * r): both factors stay in the denominator.
                        args.push_back(lhs);
                        args.push_back(rhs);
                    } else {
                        // 1 / (l / r) == r / l. `r` moves to the numerator: it is the
                        // only numerator factor left, hence the counter reset to 1.
                        // `l` joins the denominator.
                        args.push_front(rhs);
                        args_before_div = 1;
                        args.push_back(lhs);
                    }
                } else {
                    if op == &BinaryOperator::Mult {
                        // Both factors of a numerator product stay in the numerator.
                        args.push_front(lhs);
                        args.push_front(rhs);
                        args_before_div += 2;
                    } else {
                        // l / r in the numerator: `l` stays on top, `r` goes below.
                        args.push_front(lhs);
                        args_before_div += 1;
                        args.push_back(rhs);
                    }
                }
            }
            _ => {
                // Any other expression is an opaque factor; track its net exponent.
                let entry = terms.entry(arg).or_insert(0.);
                if div_side {
                    *entry -= 1.;
                } else {
                    *entry += 1.;
                }
            }
        }
    }
    let mut new_args: Vec<Rc<Expr>> = Vec::with_capacity(1 + terms.len());
    if (coeff - 1.).abs() >= std::f64::EPSILON {
        new_args.push(Rc::from(Expr::Const(coeff)));
    }
    // Here `coeff` is shadowed by each factor's net exponent.
    for (term, coeff) in terms {
        if coeff == 0. {
            continue;
        } else if (coeff - 1.).abs() < std::f64::EPSILON {
            new_args.push(Rc::clone(term));
        } else if (coeff - -1.).abs() < std::f64::EPSILON {
            let reciprocal = BinaryExpr::div(Expr::Const(1.), Rc::clone(term));
            new_args.push(Rc::from(Expr::BinaryExpr(reciprocal)));
        } else {
            let exponentiation = BinaryExpr::exp(Rc::clone(term), Expr::Const(coeff));
            new_args.push(Rc::from(Expr::BinaryExpr(exponentiation)));
        }
    }
    match new_args.len() {
        0 => Rc::from(Expr::Const(1.)),
        1 => new_args.remove(0),
        _ => unflatten_binary_expr(&new_args, BinaryOperator::Mult, UnflattenStrategy::Left),
    }
}
/// Strips every grouping wrapper (`Parend` / `Bracketed`) from `arg`, however deeply
/// nested, and returns a reference to the innermost expression.
///
/// Any other expression is returned unchanged.
fn unwrap_expr(arg: &Rc<Expr>) -> &Rc<Expr> {
    match arg.as_ref() {
        // Recurse so `((x))` unwraps all the way to `x`, matching how `flatten_expr`
        // treats grouping as fully transparent.
        Expr::Parend(inner) | Expr::Bracketed(inner) => unwrap_expr(inner),
        _ => arg,
    }
}
#[cfg(test)]
mod tests {
    use super::flatten_expr;
    use crate::grammar::*;
    use crate::utils::normalize;
    use crate::{parse_expression, scan};
    use std::rc::Rc;

    /// Scans and parses `program`, returning the expression it contains.
    ///
    /// Panics if the parsed statement is not an expression statement.
    fn parse(program: &str) -> Expr {
        let tokens = scan(program).tokens;
        let (parsed, _) = parse_expression(tokens);
        match parsed {
            Stmt::Expr(expr) => expr,
            _ => unreachable!("test input did not parse to an expression statement"),
        }
    }

    /// Each case is `"<input> -> <expected>"`. The expected side is the `s_form` of the
    /// normalized, flattened input.
    static CASES: &[&str] = &[
        "1 + 2 + 3 -> 6",
        "1 + x + x -> (+ 1 (* x 2))",
        "x + y + 1 -> (+ (+ x y) 1)",
        "x + 0 -> x",
        "1 - 1 -> 0",
        "1 + 2 - 3 -> 0",
        "1 - 2 + 3 -> 2",
        "a - a + 1 -> 1",
        "a + 1 - 1 -> a",
        "10 * 2x / 5 / 2 / 4x -> (* 0.5 (^ x 2))",
        "x * 2 / y / (5 / (x / y)) -> (* (* 0.4 (^ x 2)) (^ y -2))",
        "x * x -> (^ x 2)",
        "x / x -> 1",
        "x / x * x / x * x / x -> 1",
    ];

    #[test]
    fn flatten_cases() {
        for case in CASES {
            let mut split = case.split(" -> ");
            let input = split.next().unwrap();
            let expected_flattened = split
                .next()
                .expect("every case must have the form `<input> -> <expected>`");
            let expr = parse(input);
            let flattened = normalize(flatten_expr(&Rc::from(expr))).s_form();
            // Name the failing input: the loop covers many cases in one #[test].
            assert_eq!(flattened, expected_flattened, "while flattening `{}`", input);
        }
    }
}