use crate::ast::*;
use lazy_static::lazy_static;
use pest::prec_climber::{Assoc, Operator, PrecClimber};
use pest::{
iterators::{Pair, Pairs},
Parser,
};
use pest_derive::Parser;
#[derive(Debug, PartialEq, Eq)]
pub enum ParserError {
InvalidNumber(String),
InvalidOperation(String),
InvalidOperand(String),
InvalidExpression(String),
InvalidSymbol(String),
InvalidStatement(String),
EmptyStatement,
MissingAssignmentTarget(String),
MissingAssignment(String),
MissingAssignmentExpression(String),
MissingSolveForLeftExpression(String),
MissingSolveForRightExpression(String),
MissingSolveForSymbol(String),
}
#[derive(Parser)]
#[grammar = "equation.pest"]
pub struct EquationParser;
lazy_static! {
static ref PREC_CLIMBER: PrecClimber<Rule> = {
use Assoc::*;
use Rule::*;
PrecClimber::new(vec![
Operator::new(add, Left) | Operator::new(subtract, Left),
Operator::new(multiply, Left) | Operator::new(divide, Left) | Operator::new(rem, Left),
Operator::new(power, Right),
])
};
}
fn parse_num(pair: Pair<Rule>) -> Result<Operand, ParserError> {
match pair.as_str().parse::<f64>() {
Ok(num) => Ok(Operand::Number(num)),
Err(_) => Err(ParserError::InvalidNumber(pair.as_str().to_string())),
}
}
fn new_operand_term(lhs: Operand, op: Operation, rhs: Operand) -> Operand {
Operand::Term(Box::new(Term { op, lhs, rhs }))
}
fn parse_term(
lhs: Result<Operand, ParserError>,
op: Pair<Rule>,
rhs: Result<Operand, ParserError>,
) -> Result<Operand, ParserError> {
let lhs = lhs?;
let rhs = rhs?;
match op.as_rule() {
Rule::add => Ok(new_operand_term(lhs, Operation::Add, rhs)),
Rule::subtract => Ok(new_operand_term(lhs, Operation::Sub, rhs)),
Rule::multiply => Ok(new_operand_term(lhs, Operation::Mul, rhs)),
Rule::divide => Ok(new_operand_term(lhs, Operation::Div, rhs)),
Rule::rem => Ok(new_operand_term(lhs, Operation::Rem, rhs)),
Rule::power => Ok(new_operand_term(lhs, Operation::Pow, rhs)),
_ => Err(ParserError::InvalidOperation(op.as_str().to_string())),
}
}
fn parse_operand(expression: Pairs<Rule>) -> Result<Operand, ParserError> {
PREC_CLIMBER.climb(
expression,
|pair: Pair<Rule>| match pair.as_rule() {
Rule::num => parse_num(pair),
Rule::expr => parse_operand(pair.into_inner()),
Rule::symbol => Ok(Operand::Symbol(pair.as_str().to_string())),
_ => Err(ParserError::InvalidOperand(pair.as_str().to_string())),
},
parse_term,
)
}
fn parse_assignment(assignment: Pairs<Rule>) -> Result<Statement, ParserError> {
let mut it = assignment;
let sym = it.next().ok_or(ParserError::MissingAssignmentTarget(
it.as_str().to_string(),
))?;
let sym = if Rule::symbol == sym.as_rule() {
Ok(sym.as_str())
} else {
Err(ParserError::InvalidSymbol(sym.as_str().to_string()))
}?;
let sym = sym.to_string();
let op = parse_operand(
it.next()
.ok_or(ParserError::MissingAssignmentExpression(
it.as_str().to_string(),
))?
.into_inner(),
)?;
Ok(Statement::Assignment { sym, op })
}
fn parse_solve_for(solve_for: Pairs<Rule>) -> Result<Statement, ParserError> {
let mut it = solve_for;
let lhs = parse_operand(
it.next()
.ok_or(ParserError::MissingSolveForLeftExpression(
it.as_str().to_string(),
))?
.into_inner(),
)?;
let rhs = parse_operand(
it.next()
.ok_or(ParserError::MissingSolveForRightExpression(
it.as_str().to_string(),
))?
.into_inner(),
)?;
let sym = it
.next()
.ok_or(ParserError::MissingSolveForSymbol(it.as_str().to_string()))?;
let sym = if Rule::symbol == sym.as_rule() {
Ok(sym.as_str())
} else {
Err(ParserError::InvalidSymbol(sym.as_str().to_string()))
}?;
let sym = sym.to_string();
Ok(Statement::SolveFor { lhs, rhs, sym })
}
fn parse_statement(statements: Pairs<Rule>) -> Result<Statement, ParserError> {
for statement in statements {
return match statement.as_rule() {
Rule::assignment => parse_assignment(statement.into_inner()),
Rule::expr => Ok(Statement::Expression {
op: parse_operand(Pairs::single(statement))?,
}),
Rule::solvefor => parse_solve_for(statement.into_inner()),
r => Err(ParserError::InvalidStatement(format!(
"Unexpected rule: {:?}",
r
))),
};
}
return Err(ParserError::EmptyStatement);
}
pub fn parse(cmd: &str) -> Result<Statement, ParserError> {
match EquationParser::parse(Rule::statement, cmd) {
Ok(rules) => parse_statement(rules),
Err(e) => Err(ParserError::InvalidExpression(e.to_string())),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_number() {
let op = Operand::Number(12.2);
assert_eq!(Ok(Statement::Expression { op }), parse("12.2"));
}
#[test]
fn parse_symbol() {
let op = Operand::Symbol("x".to_string());
assert_eq!(Ok(Statement::Expression { op }), parse("x"));
}
#[test]
fn parse_symbol_add() {
let term = {
let lhs = Operand::Symbol("x".to_string());
let rhs = Operand::Number(1.0);
let op = Operation::Add;
Term { op, lhs, rhs }
};
let op = Operand::Term(Box::new(term));
assert_eq!(Ok(Statement::Expression { op }), parse("x + 1"));
}
#[test]
fn parse_term_add() {
let lhs = Operand::Number(3.0);
let rhs = Operand::Number(-4.0);
let op = Operation::Mul;
let op = Operand::Term(Box::new(Term { op, lhs, rhs }));
assert_eq!(Ok(Statement::Expression { op }), parse("3 * -4"));
}
#[test]
fn parse_term_mul() {
let lhs = Operand::Number(1.0);
let rhs = Operand::Number(2.0);
let op = Operation::Add;
let op = Operand::Term(Box::new(Term { op, lhs, rhs }));
assert_eq!(Ok(Statement::Expression { op }), parse("1 + 2"));
}
#[test]
fn parse_term_precedence_add_mul() {
let lhs = Operand::Number(1.0);
let rhs = {
let lhs = Operand::Number(2.0);
let rhs = Operand::Symbol("val".to_string());
let op = Operation::Mul;
Operand::Term(Box::new(Term { op, lhs, rhs }))
};
let op = Operation::Add;
let op = Operand::Term(Box::new(Term { op, lhs, rhs }));
assert_eq!(Ok(Statement::Expression { op }), parse("1 + 2 * val"));
}
#[test]
fn parse_term_precedence_sub_div_pow() {
let lhs = Operand::Number(1.0);
let rhs = {
let lhs = {
let lhs = Operand::Number(2.0);
let rhs = Operand::Symbol("exp".to_string());
let op = Operation::Pow;
Operand::Term(Box::new(Term { op, lhs, rhs }))
};
let rhs = Operand::Symbol("val".to_string());
let op = Operation::Mul;
Operand::Term(Box::new(Term { op, lhs, rhs }))
};
let op = Operation::Add;
let op = Operand::Term(Box::new(Term { op, lhs, rhs }));
assert_eq!(Ok(Statement::Expression { op }), parse("1 + 2 ^ exp * val"));
}
#[test]
fn parse_a_is_1() {
let statement = Statement::Assignment {
sym: "a".to_string(),
op: Operand::Number(1.0),
};
assert_eq!(Ok(statement), parse("a := 1"));
}
#[test]
fn parse_solve_for() {
let statement = Statement::SolveFor {
lhs: Operand::Number(13.0),
rhs: Operand::Symbol("x".to_string()),
sym: "x".to_string(),
};
assert_eq!(Ok(statement), parse("solve 13 = x for x"));
}
}