elementary-row-operation-verifier 0.0.1

A tool to verify the correctness of elementary row operations on matrices
Documentation
//! nom 表达式解析器
//!
//! 语法(优先级从低到高):
//!   expr  = sum ("/" sum)?
//!   sum   = term (("+" | "-") term)*
//!   term  = unary ("*" unary)*
//!   unary = "-" unary | atom
//!   atom  = NUMBER | FRACTION | VARIABLE | "(" expr ")"
//!
//! / 优先级最低:a / b + c = a / (b + c)

use nom::{
    IResult,
    branch::alt,
    character::complete::{char, digit1, multispace0, one_of},
    combinator::{map, map_res, not, opt, recognize},
    multi::many0,
    sequence::{delimited, pair, preceded, tuple},
};

use super::ast::Expr;

// ---- 基础 ----

/// Consumes any leading whitespace via `multispace0` and returns the skipped slice.
///
/// Zero whitespace characters is still a match, so this is used freely as an
/// optional separator in front of tokens.
fn ws(input: &str) -> IResult<&str, &str> {
    let (remaining, skipped) = multispace0(input)?;
    Ok((remaining, skipped))
}

fn integer(input: &str) -> IResult<&str, i64> {
    let (input, sign) = opt(one_of("+-"))(input)?;
    let (input, digits) = digit1(input)?;
    let sign = match sign {
        Some('-') => -1,
        _ => 1,
    };
    let n: i64 = digits.parse().unwrap();
    Ok((input, sign * n))
}

fn fraction(input: &str) -> IResult<&str, Expr> {
    let (input, (num, _, den)) = tuple((integer, char('/'), integer))(input)?;
    Ok((input, Expr::frac(num, den)))
}

fn variable(input: &str) -> IResult<&str, Expr> {
    let (input, first) =
        one_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZαβγλμ")(input)?;
    let (input, rest): (&str, Vec<char>) = many0(one_of(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789αβγλμ",
    ))(input)?;
    let name: String = std::iter::once(first).chain(rest).collect();
    Ok((input, Expr::var(&name)))
}

/// Parses a bare (optionally signed) integer literal and wraps it in `Expr::num`.
fn number(input: &str) -> IResult<&str, Expr> {
    integer(input).map(|(rest, value)| (rest, Expr::num(value)))
}

// ---- 原子 ----

fn paren(input: &str) -> IResult<&str, Expr> {
    delimited(preceded(ws, char('(')), expr, preceded(ws, char(')')))(input)
}

fn atom(input: &str) -> IResult<&str, Expr> {
    preceded(ws, alt((fraction, number, variable, paren)))(input)
}

// ---- 一元 ----

/// Parses any number of prefix `-` signs followed by an atom.
///
/// Each `-` wraps its operand in `Expr::Neg`, so `--x` nests two `Neg` nodes.
/// Because the negation branch is tried before `atom`, `-5` becomes
/// `Neg(Num(5))` rather than a negative integer literal.
fn unary(input: &str) -> IResult<&str, Expr> {
    let negation = map(preceded(pair(ws, char('-')), unary), |operand| {
        Expr::Neg(Box::new(operand))
    });
    alt((negation, atom))(input)
}

// ---- 项(隐式乘) ----

fn term(input: &str) -> IResult<&str, Expr> {
    let (input, first) = unary(input)?;
    let (input, rest): (&str, Vec<Expr>) = many0(preceded(ws, unary))(input)?;

    let result = rest
        .into_iter()
        .fold(first, |acc, e| Expr::Mul(Box::new(acc), Box::new(e)));
    Ok((input, result))
}

// ---- 和 ----

/// Parses terms joined by `+` / `-` into a left-associative chain.
///
/// `a - b + c` becomes `Add(Sub(a, b), c)`. Whitespace is allowed around the
/// operators.
fn sum(input: &str) -> IResult<&str, Expr> {
    let (input, head) = term(input)?;
    let (input, tail): (&str, Vec<(char, Expr)>) =
        many0(pair(preceded(ws, one_of("+-")), preceded(ws, term)))(input)?;

    let mut total = head;
    for (op, rhs) in tail {
        let (lhs, rhs) = (Box::new(total), Box::new(rhs));
        total = match op {
            '+' => Expr::Add(lhs, rhs),
            '-' => Expr::Sub(lhs, rhs),
            // `one_of("+-")` above can only yield these two characters.
            _ => unreachable!(),
        };
    }
    Ok((input, total))
}

// ---- 表达式(/ 最低优先级) ----

pub fn expr(input: &str) -> IResult<&str, Expr> {
    let (input, left) = sum(input)?;
    let (input, right) = opt(preceded(preceded(ws, char('/')), preceded(ws, sum)))(input)?;

    match right {
        Some(r) => Ok((input, Expr::Div(Box::new(left), Box::new(r)))),
        None => Ok((input, left)),
    }
}

// ---- 公开接口 ----

pub fn parse_expr(input: &str) -> Result<Expr, String> {
    match expr(input.trim()) {
        Ok((rest, result)) if rest.trim().is_empty() => Ok(result),
        Ok((rest, _)) => Err(format!("unparsed remainder: '{}'", rest)),
        Err(e) => Err(format!("parse error: {}", e)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Small helpers so the expected ASTs below read close to the source text.
    fn bx(e: Expr) -> Box<Expr> {
        Box::new(e)
    }

    fn parsed(src: &str) -> Expr {
        parse_expr(src).unwrap()
    }

    #[test]
    fn test_number() {
        assert_eq!(parsed("3"), Expr::num(3));
        // A leading `-` is parsed as unary negation, not as part of the literal.
        assert_eq!(parsed("-5"), Expr::Neg(bx(Expr::num(5))));
    }

    #[test]
    fn test_fraction() {
        assert_eq!(parsed("1/2"), Expr::frac(1, 2));
    }

    #[test]
    fn test_variable() {
        assert_eq!(parsed("lambda"), Expr::var("lambda"));
    }

    #[test]
    fn test_add() {
        let expected = Expr::Add(bx(Expr::num(1)), bx(Expr::num(2)));
        assert_eq!(parsed("1 + 2"), expected);
    }

    #[test]
    fn test_implicit_mul() {
        let expected = Expr::Mul(bx(Expr::num(2)), bx(Expr::var("m")));
        assert_eq!(parsed("2 m"), expected);
    }

    #[test]
    fn test_div_low_precedence() {
        // `/` binds loosest: a / b + c == a / (b + c)
        let denominator = Expr::Add(bx(Expr::var("b")), bx(Expr::var("c")));
        let expected = Expr::Div(bx(Expr::var("a")), bx(denominator));
        assert_eq!(parsed("a / b + c"), expected);
    }

    #[test]
    fn test_paren() {
        let two_m = Expr::Mul(bx(Expr::num(2)), bx(Expr::var("m")));
        let n_plus_p = Expr::Add(bx(Expr::var("n")), bx(Expr::var("p")));
        let quotient = Expr::Div(bx(two_m), bx(n_plus_p));
        let expected = Expr::Mul(bx(Expr::num(2)), bx(quotient));
        assert_eq!(parsed("2 (2 m / n + p)"), expected);
    }
}