1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use std::fmt;

#[derive(PartialEq, Eq, Debug)]
pub enum Expression {
    /// 92
    Number(i64),
    /// sqr 42
    Sqr(i64),
    /// + 3 4
    BinExpression { bin_op: BinOp, lhs: i64, rhs: i64 },
}

#[derive(PartialEq, Eq, Debug)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
}

#[derive(PartialEq, Eq, Debug)]
pub enum CalcError {
    Eof,
    BadToken,
    ArithError,
}

impl fmt::Display for CalcError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CalcError::Eof => write!(f, "eof"),
            CalcError::BadToken => write!(f, "bad token"),
            CalcError::ArithError => write!(f, "arithmetic error"),
        }
    }
}

impl std::error::Error for CalcError {}

/// Parse input into an `Expression`.
pub fn parse(input: &str) -> Result<Expression, CalcError> {
    let tokens = input.split_ascii_whitespace().collect::<Vec<&str>>();
    let first = match tokens.get(0) {
        Some(it) => *it,
        None => return Err(CalcError::Eof),
    };

    let bin_op = match first {
        "+" => BinOp::Add,
        "-" => BinOp::Sub,
        "*" => BinOp::Mul,
        "/" => BinOp::Div,
        "sqr" => {
            let arg = match tokens.get(1) {
                None => return Err(CalcError::Eof),
                Some(token) => match token.parse::<i64>() {
                    Ok(n) => n,
                    Err(_) => return Err(CalcError::BadToken),
                },
            };
            return Ok(Expression::Sqr(arg));
        }
        _ => {
            return match first.parse::<i64>() {
                Ok(n) => Ok(Expression::Number(n)),
                Err(_) => Err(CalcError::BadToken),
            }
        }
    };
    let lhs = match tokens.get(1) {
        None => return Err(CalcError::Eof),
        Some(token) => match token.parse::<i64>() {
            Ok(n) => n,
            Err(_) => return Err(CalcError::BadToken),
        },
    };
    let rhs = match tokens.get(2) {
        None => return Err(CalcError::Eof),
        Some(token) => match token.parse::<i64>() {
            Ok(n) => n,
            Err(_) => return Err(CalcError::BadToken),
        },
    };

    Ok(Expression::BinExpression { bin_op, lhs, rhs })
}

/// Evaluates an expression.
pub fn eval(expr: Expression) -> Result<i64, CalcError> {
    let res = match expr {
        Expression::Number(n) => n,
        Expression::Sqr(n) => n * n,
        Expression::BinExpression { bin_op, lhs, rhs } => {
            let op_res = match bin_op {
                BinOp::Add => lhs.checked_add(rhs),
                BinOp::Sub => lhs.checked_sub(rhs),
                BinOp::Mul => lhs.checked_mul(rhs),
                BinOp::Div => lhs.checked_div(rhs),
            };

            match op_res {
                Some(it) => it,
                None => return Err(CalcError::ArithError),
            }
        }
    };
    Ok(res)
}

#[test]
fn test() {
    assert_eq!(parse("92"), Ok(Expression::Number(92)),);
    assert_eq!(parse("sqr 92"), Ok(Expression::Sqr(92)),);
    assert_eq!(
        parse("+ 90 2"),
        Ok(Expression::BinExpression {
            bin_op: BinOp::Add,
            lhs: 90,
            rhs: 2,
        }),
    );
}