// die_sir/parsedie/ast.rs

1use rand::rngs::OsRng;
2use rand::Rng;
3use std::error;
4
/// A node of a parsed dice-expression tree.
///
/// For the two-operand variants, the first field is the left-hand operand and
/// the second field is the right-hand operand. The tree contains only integers
/// and boxed sub-nodes, so it is totally comparable and hashable.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Node {
    /// Sum of the two operands.
    Add(Box<Node>, Box<Node>),
    /// First operand minus the second.
    Subtract(Box<Node>, Box<Node>),
    /// Product of the two operands.
    Multiply(Box<Node>, Box<Node>),
    /// First operand divided by the second.
    Divide(Box<Node>, Box<Node>),
    /// Exponentiation: first operand raised to the power of the second.
    Caret(Box<Node>, Box<Node>),
    /// Arithmetic negation of the operand.
    Negative(Box<Node>),
    /// Dice roll: first operand is the number of rolls, second the number of sides.
    Die(Box<Node>, Box<Node>),
    /// An integer value.
    Number(i128),
}
16
/// The value an expression evaluates to.
#[derive(Debug, Clone, PartialEq)]
pub enum EvalResult {
    /// A single numeric value: arithmetic results, a zero-operand die, or a single roll.
    Number(f64),
    /// The individual face values of a dice roll that produced a list of rolls.
    DieResult(Vec<i128>),
}
22
23impl std::fmt::Display for EvalResult {
24    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
25        match &self {
26            self::EvalResult::Number(e) => write!(f, "{}", e),
27            self::EvalResult::DieResult(e) => write!(f, "{:?}", e),
28        }
29    }
30}
31
32pub fn eval(expr: Node) -> Result<EvalResult, Box<dyn error::Error>> {
33    match expr {
34        Node::Number(i) => Ok(EvalResult::Number(i as f64)),
35        Node::Add(expr1, expr2) => {
36            let lhs = eval(*expr1)?;
37            let rhs = eval(*expr2)?;
38
39            match (lhs, rhs) {
40                (EvalResult::Number(num1), EvalResult::Number(num2)) => {
41                    Ok(EvalResult::Number(num1 + num2))
42                }
43                _ => Err("Cannot Add DieResult to Number".into()),
44            }
45        }
46        Node::Subtract(expr1, expr2) => {
47            let lhs = eval(*expr1)?;
48            let rhs = eval(*expr2)?;
49
50            match (lhs, rhs) {
51                (EvalResult::Number(num1), EvalResult::Number(num2)) => {
52                    Ok(EvalResult::Number(num1 - num2))
53                }
54                _ => Err("Cannot Subtract DieResult to Number".into()),
55            }
56        }
57        Node::Multiply(expr1, expr2) => {
58            let lhs = eval(*expr1)?;
59            let rhs = eval(*expr2)?;
60
61            match (lhs, rhs) {
62                (EvalResult::Number(num1), EvalResult::Number(num2)) => {
63                    Ok(EvalResult::Number(num1 * num2))
64                }
65                _ => Err("Cannot Multiply DieResult by Number".into()),
66            }
67        }
68        Node::Divide(expr1, expr2) => {
69            let lhs = eval(*expr1)?;
70            let rhs = eval(*expr2)?;
71
72            match (lhs, rhs) {
73                (EvalResult::Number(num1), EvalResult::Number(num2)) => {
74                    Ok(EvalResult::Number(num1 / num2))
75                }
76                _ => Err("Cannot Divide DieResult by Number".into()),
77            }
78        }
79        Node::Negative(expr1) => {
80            let value = eval(*expr1)?;
81
82            match value {
83                EvalResult::Number(val) => Ok(EvalResult::Number(-val)),
84                _ => Err("DieResult cannot be negative".into()),
85            }
86        }
87        Node::Caret(expr1, expr2) => {
88            let lhs = eval(*expr1)?;
89            let rhs = eval(*expr2)?;
90
91            match (lhs, rhs) {
92                (EvalResult::Number(num1), EvalResult::Number(num2)) => {
93                    Ok(EvalResult::Number(num1.powf(num2)))
94                }
95                _ => Err("Cannot use DieResult in power operation".into()),
96            }
97        }
98        Node::Die(expr1, expr2) => {
99            let num_rolls = eval(*expr1)?;
100            let num_sides = eval(*expr2)?;
101            let mut results: Vec<i128> = Vec::new();
102            let mut rng = OsRng;
103
104            if let (EvalResult::Number(num_rolls), EvalResult::Number(num_sides)) =
105                (num_rolls, num_sides)
106            {
107                if num_rolls == 0.0 || num_sides == 0.0 {
108                    return Ok(EvalResult::Number(0.0));
109                } else if num_rolls == 1.0 {
110                    return Ok(EvalResult::Number(
111                        rng.gen_range(1..=(num_sides as i128)) as f64
112                    ));
113                }
114
115                for _ in 0..(num_rolls as i128) {
116                    results.push(rng.gen_range(1..=(num_sides as i128)));
117                }
118
119                Ok(EvalResult::DieResult(results))
120            } else {
121                Err("Die expressions must have numeric operands".into())
122            }
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsedie::parser::Parser;

    /// Parses and evaluates `src`, panicking on any parse or evaluation error.
    fn eval_str(src: &'static str) -> EvalResult {
        let ast = Parser::new(src).unwrap().parse().unwrap();
        eval(ast).unwrap()
    }

    /// Boxes an integer leaf node.
    fn num(value: i128) -> Box<Node> {
        Box::new(Node::Number(value))
    }

    #[test]
    fn test_expr1() {
        assert_eq!(eval_str("1+2-3"), EvalResult::Number(0.0));
    }

    #[test]
    fn test_expr2() {
        assert_eq!(eval_str("3+2-1*5/4"), EvalResult::Number(3.75));
    }

    #[test]
    fn die_with_zero_operand_is_zero() {
        assert_eq!(eval(Node::Die(num(0), num(6))).unwrap(), EvalResult::Number(0.0));
        assert_eq!(eval(Node::Die(num(4), num(0))).unwrap(), EvalResult::Number(0.0));
    }

    #[test]
    fn single_roll_is_a_number() {
        // A one-sided die always rolls 1, so the result is deterministic.
        assert_eq!(eval(Node::Die(num(1), num(1))).unwrap(), EvalResult::Number(1.0));
    }

    #[test]
    fn multiple_rolls_are_a_die_result() {
        assert_eq!(
            eval(Node::Die(num(3), num(1))).unwrap(),
            EvalResult::DieResult(vec![1, 1, 1])
        );
    }

    #[test]
    fn rolls_stay_within_sides() {
        match eval(Node::Die(num(20), num(6))).unwrap() {
            EvalResult::DieResult(rolls) => {
                assert_eq!(rolls.len(), 20);
                assert!(rolls.iter().all(|r| (1..=6).contains(r)));
            }
            other => panic!("expected DieResult, got {:?}", other),
        }
    }

    #[test]
    fn die_result_cannot_be_used_in_arithmetic() {
        assert!(eval(Node::Add(Box::new(Node::Die(num(2), num(6))), num(1))).is_err());
    }
}