use nom::{
IResult,
branch::alt,
character::complete::{char, digit1, multispace0, one_of},
combinator::{map, opt},
multi::many0,
sequence::{delimited, pair, preceded, tuple},
};
use super::ast::Expr;
/// Skips a (possibly empty) run of whitespace and returns the skipped slice.
///
/// `multispace0` accepts zero characters, so this never fails on `&str` input;
/// it exists so every token parser can tolerate leading whitespace.
fn ws(input: &str) -> IResult<&str, &str> {
    let (rest, skipped) = multispace0(input)?;
    Ok((rest, skipped))
}
fn integer(input: &str) -> IResult<&str, i64> {
let (input, sign) = opt(one_of("+-"))(input)?;
let (input, digits) = digit1(input)?;
let sign = match sign {
Some('-') => -1,
_ => 1,
};
let n: i64 = digits.parse().unwrap();
Ok((input, sign * n))
}
fn fraction(input: &str) -> IResult<&str, Expr> {
let (input, (num, _, den)) = tuple((integer, char('/'), integer))(input)?;
Ok((input, Expr::frac(num, den)))
}
/// Parses an identifier and wraps it with `Expr::var`.
///
/// The first character must be an ASCII letter or one of `αβγλμ`. Later
/// characters may also be ASCII digits. Consecutive letters form a single name,
/// so `lambda` is one variable, not a product.
fn variable(input: &str) -> IResult<&str, Expr> {
    const HEAD: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZαβγλμ";
    const TAIL: &str = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789αβγλμ";

    let (input, head) = one_of(HEAD)(input)?;
    let (input, tail): (&str, Vec<char>) = many0(one_of(TAIL))(input)?;

    let mut name = String::new();
    name.push(head);
    name.extend(tail);
    Ok((input, Expr::var(&name)))
}
/// Parses an integer literal (see `integer`) and wraps it with `Expr::num`.
fn number(input: &str) -> IResult<&str, Expr> {
    integer(input).map(|(rest, value)| (rest, Expr::num(value)))
}
/// Parses `( expr )`, allowing whitespace before each parenthesis, and returns
/// the inner expression without a wrapper node.
fn paren(input: &str) -> IResult<&str, Expr> {
    let open = preceded(ws, char('('));
    let close = preceded(ws, char(')'));
    delimited(open, expr, close)(input)
}
fn atom(input: &str) -> IResult<&str, Expr> {
preceded(ws, alt((fraction, number, variable, paren)))(input)
}
/// Parses an operand with any number of leading `-` signs. Each `-` becomes an
/// `Expr::Neg` node.
///
/// Negation is tried before `atom`, so `-5` yields `Neg(num(5))` rather than
/// the literal `num(-5)`.
fn unary(input: &str) -> IResult<&str, Expr> {
    let negation = map(preceded(ws, preceded(char('-'), unary)), |operand| {
        Expr::Neg(Box::new(operand))
    });
    alt((negation, atom))(input)
}
fn term(input: &str) -> IResult<&str, Expr> {
let (input, first) = unary(input)?;
let (input, rest): (&str, Vec<Expr>) = many0(preceded(ws, unary))(input)?;
let result = rest
.into_iter()
.fold(first, |acc, e| Expr::Mul(Box::new(acc), Box::new(e)));
Ok((input, result))
}
/// Parses terms joined by binary `+` / `-`, folded left-associatively into
/// `Expr::Add` / `Expr::Sub` nodes (`a - b + c` is `(a - b) + c`).
fn sum(input: &str) -> IResult<&str, Expr> {
    let (input, head) = term(input)?;
    let operator = preceded(ws, one_of("+-"));
    let operand = preceded(ws, term);
    let (input, tail): (&str, Vec<(char, Expr)>) = many0(pair(operator, operand))(input)?;

    let mut total = head;
    for (op, rhs) in tail {
        total = if op == '+' {
            Expr::Add(Box::new(total), Box::new(rhs))
        } else {
            // `one_of("+-")` only yields '+' or '-', so this branch is subtraction.
            Expr::Sub(Box::new(total), Box::new(rhs))
        };
    }
    Ok((input, total))
}
/// Parses a full expression. Division has the lowest precedence:
/// `a / b + c` parses as `a / (b + c)`.
///
/// At most one top-level `/` is consumed. In `a / b / c` the trailing `/ c` is
/// returned unparsed, which `parse_expr` reports as a remainder error. Tight
/// `1/2` forms are consumed earlier as fraction literals.
pub fn expr(input: &str) -> IResult<&str, Expr> {
    let (input, numerator) = sum(input)?;
    let slash = preceded(ws, char('/'));
    let divisor = preceded(ws, sum);
    let (input, denominator) = opt(preceded(slash, divisor))(input)?;

    let tree = if let Some(d) = denominator {
        Expr::Div(Box::new(numerator), Box::new(d))
    } else {
        numerator
    };
    Ok((input, tree))
}
pub fn parse_expr(input: &str) -> Result<Expr, String> {
match expr(input.trim()) {
Ok((rest, result)) if rest.trim().is_empty() => Ok(result),
Ok((rest, _)) => Err(format!("unparsed remainder: '{}'", rest)),
Err(e) => Err(format!("parse error: {}", e)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Small AST builders so expected trees read like the expressions they model.
    fn add(lhs: Expr, rhs: Expr) -> Expr {
        Expr::Add(Box::new(lhs), Box::new(rhs))
    }

    fn mul(lhs: Expr, rhs: Expr) -> Expr {
        Expr::Mul(Box::new(lhs), Box::new(rhs))
    }

    fn div(lhs: Expr, rhs: Expr) -> Expr {
        Expr::Div(Box::new(lhs), Box::new(rhs))
    }

    fn neg(inner: Expr) -> Expr {
        Expr::Neg(Box::new(inner))
    }

    #[test]
    fn test_number() {
        assert_eq!(parse_expr("3").unwrap(), Expr::num(3));
        // Leading '-' is a unary negation node, not a negative literal.
        assert_eq!(parse_expr("-5").unwrap(), neg(Expr::num(5)));
    }

    #[test]
    fn test_fraction() {
        assert_eq!(parse_expr("1/2").unwrap(), Expr::frac(1, 2));
    }

    #[test]
    fn test_variable() {
        assert_eq!(parse_expr("lambda").unwrap(), Expr::var("lambda"));
    }

    #[test]
    fn test_add() {
        let parsed = parse_expr("1 + 2").unwrap();
        assert_eq!(parsed, add(Expr::num(1), Expr::num(2)));
    }

    #[test]
    fn test_implicit_mul() {
        let parsed = parse_expr("2 m").unwrap();
        assert_eq!(parsed, mul(Expr::num(2), Expr::var("m")));
    }

    #[test]
    fn test_div_low_precedence() {
        let parsed = parse_expr("a / b + c").unwrap();
        let expected = div(Expr::var("a"), add(Expr::var("b"), Expr::var("c")));
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_paren() {
        let parsed = parse_expr("2 (2 m / n + p)").unwrap();
        let inner = div(
            mul(Expr::num(2), Expr::var("m")),
            add(Expr::var("n"), Expr::var("p")),
        );
        assert_eq!(parsed, mul(Expr::num(2), inner));
    }
}