radiate_gp/ops/expr.rs

1use crate::{Op, Tree, TreeNode};
2
3pub trait Expression<T> {
4    fn parse(expr: &str) -> Result<T, String>;
5}
6
7impl Expression<Tree<Op<f32>>> for Tree<Op<f32>> {
8    fn parse(expr: &str) -> Result<Tree<Op<f32>>, String> {
9        parse(expr).map(|node| Tree::new(node))
10    }
11}
12
13fn parse(expr: &str) -> Result<TreeNode<Op<f32>>, String> {
14    let tokens = tokenize(expr);
15    let mut pos = 0;
16    parse_expression(&tokens, &mut pos)
17}
18
/// A lexical token produced by [`tokenize`].
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// A numeric literal such as `3` or `2.5`.
    Number(f32),
    /// A variable name together with its input index.
    /// `tokenize` fills the index in for every identifier it returns.
    Identifier(String, Option<usize>),
    /// `+`
    Plus,
    /// `-` (binary subtraction or unary negation, decided by the parser)
    Minus,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `^`
    Power,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// End-of-input marker appended after the last real token.
    EOF,
}

/// Splits `expression` into a flat list of tokens ending with [`Token::EOF`].
///
/// Each identifier is tagged with the position of its name in the sorted list
/// of *distinct* identifier names. That index is the input slot the variable
/// reads from: in `"y * x + y"`, `x` gets `0` and both `y`s get `1`.
///
/// Any Unicode whitespace between tokens is skipped.
///
/// # Panics
///
/// Panics on a character that cannot begin a token. Also panics on a run of
/// digits and dots that is not a valid `f32`, e.g. `"1.2.3"` or a lone `"."`.
fn tokenize(expression: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut vars: Vec<String> = Vec::new();
    let mut chars = expression.chars().peekable();

    while let Some(&ch) = chars.peek() {
        match ch {
            c if c.is_whitespace() => {
                chars.next();
            }
            '0'..='9' | '.' => {
                let mut num = String::new();
                while let Some(&c) = chars.peek() {
                    if c.is_ascii_digit() || c == '.' {
                        num.push(c);
                        chars.next();
                    } else {
                        break;
                    }
                }
                let value = num
                    .parse::<f32>()
                    .unwrap_or_else(|_| panic!("Invalid number: {}", num));
                tokens.push(Token::Number(value));
            }
            'a'..='z' | 'A'..='Z' => {
                let mut ident = String::new();
                while let Some(&c) = chars.peek() {
                    if c.is_alphanumeric() || c == '_' {
                        ident.push(c);
                        chars.next();
                    } else {
                        break;
                    }
                }
                vars.push(ident.clone());
                tokens.push(Token::Identifier(ident, None));
            }
            _ => {
                let token = match ch {
                    '+' => Token::Plus,
                    '-' => Token::Minus,
                    '*' => Token::Multiply,
                    '/' => Token::Divide,
                    '^' => Token::Power,
                    '(' => Token::LParen,
                    ')' => Token::RParen,
                    _ => panic!("Unexpected character: {}", ch),
                };
                chars.next();
                tokens.push(token);
            }
        }
    }

    tokens.push(Token::EOF);

    // `Vec::dedup` only removes *adjacent* duplicates, so the names must be
    // sorted first. Deduplicating before sorting leaves repeats such as
    // `[x, x, y]` and shifts the index of every name sorting after a repeat.
    vars.sort();
    vars.dedup();

    for token in tokens.iter_mut() {
        if let Token::Identifier(name, index) = token {
            // Every identifier name was pushed into `vars` above, so the
            // search always succeeds.
            *index = vars.binary_search(name).ok();
        }
    }

    tokens
}
117
118fn parse_expression(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
119    let mut node = parse_term(tokens, pos)?;
120
121    while let Some(token) = tokens.get(*pos) {
122        match token {
123            Token::Plus | Token::Minus => {
124                let op = token.clone();
125                *pos += 1;
126                let right = parse_term(tokens, pos)?;
127                node = TreeNode::new(match op {
128                    Token::Plus => Op::add(),
129                    Token::Minus => Op::sub(),
130                    _ => unreachable!(),
131                })
132                .attach(node)
133                .attach(right);
134            }
135            _ => break,
136        }
137    }
138
139    Ok(node)
140}
141
142fn parse_term(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
143    let mut node = parse_power(tokens, pos)?;
144
145    while let Some(token) = tokens.get(*pos) {
146        match token {
147            Token::Multiply | Token::Divide => {
148                let op = token.clone();
149                *pos += 1;
150                let right = parse_power(tokens, pos)?;
151                node = TreeNode::new(match op {
152                    Token::Multiply => Op::mul(),
153                    Token::Divide => Op::div(),
154                    _ => unreachable!(),
155                })
156                .attach(node)
157                .attach(right);
158            }
159            _ => break,
160        }
161    }
162
163    Ok(node)
164}
165
166fn parse_power(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
167    let mut node = parse_factor(tokens, pos)?;
168
169    if let Some(Token::Power) = tokens.get(*pos) {
170        *pos += 1;
171        let right = parse_power(tokens, pos)?;
172        node = TreeNode::new(Op::pow()).attach(node).attach(right);
173    }
174
175    Ok(node)
176}
177
178fn parse_factor(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
179    match tokens.get(*pos) {
180        Some(Token::Minus) => {
181            *pos += 1;
182            Ok(TreeNode::new(Op::neg()).attach(parse_factor(tokens, pos)?))
183        }
184        Some(Token::Plus) => {
185            *pos += 1;
186            parse_factor(tokens, pos)
187        }
188        Some(Token::Number(n)) => {
189            *pos += 1;
190            Ok(TreeNode::new(Op::constant(*n)))
191        }
192        Some(Token::Identifier(_, var)) => {
193            *pos += 1;
194            Ok(TreeNode::new(Op::var(var.unwrap())))
195        }
196        Some(Token::LParen) => {
197            *pos += 1;
198            let node = parse_expression(tokens, pos)?;
199            if let Some(Token::RParen) = tokens.get(*pos) {
200                *pos += 1;
201                Ok(node)
202            } else {
203                Err("Expected ')'".to_string())
204            }
205        }
206        token => Err(format!("Unexpected token: {:?}", token)),
207    }
208}
209
#[cfg(test)]
mod test {
    use crate::{Eval, Tree, ops::expr::Expression};

    // Each test parses an expression string, evaluates the resulting tree, and
    // compares against a value worked out by hand (or an equivalent closure).

    #[test]
    fn test_tokenize() {
        let tree = match Tree::parse("1 + 2 * (3 * 4)^5") {
            Ok(tree) => tree,
            Err(_) => panic!("Failed to parse expression"),
        };
        // 1 + 2 * 12^5 = 1 + 2 * 248832
        assert_eq!(tree.eval(&[]), 497665.0);
    }

    #[test]
    fn test_tokenize_with_vars() {
        let tree = match Tree::parse("a + b * (c * d)^e") {
            Ok(tree) => tree,
            Err(_) => panic!("Failed to parse expression"),
        };
        // Variables are indexed alphabetically: a=1, b=2, c=3, d=4, e=5.
        assert_eq!(tree.eval(&[1.0, 2.0, 3.0, 4.0, 5.0]), 497665.0);
    }

    #[test]
    fn test_tokenize_with_vars_and_negation() {
        let tree = match Tree::parse("5 - x * (34 * 3)^2") {
            Ok(tree) => tree,
            Err(_) => panic!("Failed to parse expression"),
        };
        // 5 - 3 * 102^2 = 5 - 3 * 10404
        assert_eq!(tree.eval(&[3.0]), -31207.0);
    }

    #[test]
    fn test_tokenize_with_vars_and_negation_and_parens() {
        let expected = |x: f32| 4.0 * x.powf(3.0) - 3.0 * x.powf(2.0) + x;

        let tree = match Tree::parse("4 * x^3 - 3 * x^2 + x") {
            Ok(tree) => tree,
            Err(_) => panic!("Failed to parse expression"),
        };

        // Step from -0.9 towards 1.0 in increments of 0.1 (20 samples).
        let mut input = -1.0;
        for _ in 0..20 {
            input += 0.1;
            let output = tree.eval(&[input]);
            assert!((output - expected(input)).abs() < 0.0001);
        }
    }
}