use crate::expr::Expr;
use crate::parse::value::value;
use nom::{
branch::alt,
bytes::complete::tag,
character::complete::multispace0,
combinator::{map, map_res},
error::ParseError,
sequence::delimited,
IResult, Parser,
};
/// Parses an expression from the start of `input`.
///
/// This is the public entry point of the operator grammar. It hands the
/// input to `lte_expr`, the outermost (loosest-binding) operator level, and
/// returns that parser's result unchanged: the unconsumed remainder together
/// with the parsed `Expr`, or the nom error.
pub fn expr(input: &str) -> IResult<&str, Expr> {
    lte_expr(input)
}
fn lte_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(lt_expr, "<=", lte_expr, Expr::lte), lt_expr))(input)
}
fn lt_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(gte_expr, "<", lt_expr, Expr::lt), gte_expr))(input)
}
fn gte_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(gt_expr, ">=", gte_expr, Expr::gte), gt_expr))(input)
}
fn gt_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(neq_expr, ">", gt_expr, Expr::gt), neq_expr))(input)
}
fn neq_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(eq_expr, "!=", neq_expr, Expr::neq), eq_expr))(input)
}
fn eq_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(sub_expr, "==", eq_expr, Expr::eq), sub_expr))(input)
}
fn sub_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(add_expr, "-", sub_expr, Expr::sub), add_expr))(input)
}
fn add_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(div_expr, "+", add_expr, Expr::add), div_expr))(input)
}
fn div_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(mult_expr, "/", div_expr, Expr::div), mult_expr))(input)
}
fn mult_expr(input: &str) -> IResult<&str, Expr> {
alt((binary_expr(not_expr, "*", mult_expr, Expr::mult), not_expr))(input)
}
/// Level directly below multiplication.
///
/// No operator is parsed at this level: the input is forwarded to `par_expr`
/// and its result is returned as is.
fn not_expr(input: &str) -> IResult<&str, Expr> {
    par_expr(input)
}
fn par_expr(input: &str) -> IResult<&str, Expr> {
alt((
map_res(
delimited(tag("("), take_until_unbalanced('(', ')'), tag(")")),
|val| expr(val).map(|parsed| parsed.1),
),
idx_expr,
))(input)
}
/// Level directly below parenthesized expressions.
///
/// No syntax of its own is parsed here: the input is forwarded to
/// `value_expr` and its result is returned as is.
fn idx_expr(input: &str) -> IResult<&str, Expr> {
    value_expr(input)
}
/// Parses a single value with the `value` parser and wraps the result in
/// `Expr::Value`.
fn value_expr(input: &str) -> IResult<&str, Expr> {
    let mut wrapped_value = map(value, Expr::Value);
    wrapped_value(input)
}
fn binary_expr<'a, O1, O2, E, F, G, H>(
mut left: F,
sep: &'static str,
mut right: G,
expr: H,
) -> impl FnMut(&'a str) -> IResult<&'a str, Expr, E>
where
E: ParseError<&'a str>,
F: Parser<&'a str, O1, E>,
G: Parser<&'a str, O2, E>,
H: Fn(O1, O2) -> Expr,
{
move |input: &'a str| {
let (input, _) = multispace0(input)?;
let (input, o1) = left.parse(input)?;
let (input, _) = multispace0(input)?;
let (input, _) = tag(sep)(input)?;
let (input, _) = multispace0(input)?;
let (input, o2) = right.parse(input)?;
let (input, _) = multispace0(input)?;
Ok((input, expr(o1, o2)))
}
}
/// Returns a parser that takes input up to the first unbalanced
/// `closing_bracket`.
///
/// The parser scans for `opening_bracket`, `closing_bracket` and `\`:
///
/// * a `\` escapes the character that follows it, so neither is counted;
/// * each `opening_bracket` must be matched by a later `closing_bracket`;
/// * at the first `closing_bracket` with no matching opener, the parser
///   succeeds with `(rest, taken)`, where `rest` starts at that bracket.
///
/// If the input ends with every bracket balanced, the whole input is taken
/// and `rest` is empty. If it ends with an opener still unmatched, an
/// `ErrorKind::TakeUntil` error is returned.
///
/// A `\` as the very last character has nothing to escape. It is skipped on
/// its own; previously the scan position moved past the end of the input and
/// the next slice operation panicked.
pub fn take_until_unbalanced(
    opening_bracket: char,
    closing_bracket: char,
) -> impl Fn(&str) -> IResult<&str, &str> {
    move |i: &str| {
        let mut index = 0;
        // Number of openers seen so far that are still waiting for a closer.
        let mut depth: usize = 0;
        while let Some(offset) = i[index..].find(&[opening_bracket, closing_bracket, '\\'][..]) {
            index += offset;
            let mut chars = i[index..].chars();
            match chars.next() {
                Some('\\') => {
                    // Skip the escape character itself...
                    index += '\\'.len_utf8();
                    // ...and the escaped character, if there is one.
                    if let Some(escaped) = chars.next() {
                        index += escaped.len_utf8();
                    }
                }
                Some(c) if c == opening_bracket => {
                    depth += 1;
                    index += opening_bracket.len_utf8();
                }
                Some(c) if c == closing_bracket => {
                    if depth == 0 {
                        // Unbalanced closer: stop right in front of it.
                        return Ok((&i[index..], &i[..index]));
                    }
                    depth -= 1;
                    index += closing_bracket.len_utf8();
                }
                // `find` only reports positions of the three characters above.
                _ => unreachable!(),
            }
        }
        if depth == 0 {
            Ok(("", i))
        } else {
            Err(nom::Err::Error(nom::error::Error::from_error_kind(
                i,
                nom::error::ErrorKind::TakeUntil,
            )))
        }
    }
}
/// Tests for the public `expr` parser: one per operator, plus precedence,
/// whitespace handling and parentheses.
#[cfg(test)]
mod when_parsing_expressions {
    use super::*;

    #[test]
    fn should_parse_less_than_equal() {
        let parsed = expr("12<=4");
        let expected = Expr::lte(Expr::value(12), Expr::value(4));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_less_than() {
        let parsed = expr("1<4");
        let expected = Expr::lt(Expr::value(1), Expr::value(4));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_greater_than_equal() {
        let parsed = expr("1>=15");
        let expected = Expr::gte(Expr::value(1), Expr::value(15));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_greater_than() {
        let parsed = expr("1>4");
        let expected = Expr::gt(Expr::value(1), Expr::value(4));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_not_equal() {
        let parsed = expr("14!=8");
        let expected = Expr::neq(Expr::value(14), Expr::value(8));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_equal() {
        let parsed = expr("2==8");
        let expected = Expr::eq(Expr::value(2), Expr::value(8));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_subtraction() {
        let parsed = expr("12-4");
        let expected = Expr::sub(Expr::value(12), Expr::value(4));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_addition() {
        let parsed = expr("1+4");
        let expected = Expr::add(Expr::value(1), Expr::value(4));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_division() {
        let parsed = expr("1/15");
        let expected = Expr::div(Expr::value(1), Expr::value(15));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_multiplication() {
        let parsed = expr("1*4");
        let expected = Expr::mult(Expr::value(1), Expr::value(4));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_expression_inside_parentheses() {
        let parsed = expr("(1+1)");
        let expected = Expr::add(Expr::value(1), Expr::value(1));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_parse_correct_precedence() {
        let parsed = expr("14*128+4/5*3-1");
        // Expected grouping: ((14*128) + (4 / (5*3))) - 1.
        let product = Expr::mult(Expr::value(14), Expr::value(128));
        let quotient = Expr::div(Expr::value(4), Expr::mult(Expr::value(5), Expr::value(3)));
        let expected = Expr::sub(Expr::add(product, quotient), Expr::value(1));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_ignore_space_and_newlines() {
        let parsed = expr("3 * 4\r\n + 8\r\n ");
        let expected = Expr::add(Expr::mult(Expr::value(3), Expr::value(4)), Expr::value(8));
        should_parse(parsed, expected);
    }

    #[test]
    fn should_override_precedence_with_parentheses() {
        let parsed = expr("(3+2)*(1-5)");
        let sum = Expr::add(Expr::value(3), Expr::value(2));
        let difference = Expr::sub(Expr::value(1), Expr::value(5));
        should_parse(parsed, Expr::mult(sum, difference));
    }

    // Placeholder: not implemented yet, so it is excluded from normal runs.
    #[ignore]
    #[test]
    fn should_parse_nested_parentheses() {
        todo!()
    }

    /// Asserts that `result` is a successful parse yielding `expected_expr`
    /// with no input left over.
    fn should_parse(result: IResult<&str, Expr>, expected_expr: Expr) {
        let (remaining, actual) = result.unwrap();
        assert_eq!(actual, expected_expr);
        assert_eq!(remaining, "");
    }
}