1use crate::{Op, Tree, TreeNode};
2
/// Conversion from a textual expression into a value of type `T`.
pub trait Expression<T> {
    /// Parses `expr`, returning the parsed value on success or an error message on failure.
    fn parse(expr: &str) -> Result<T, String>;
}
6
7impl Expression<Tree<Op<f32>>> for Tree<Op<f32>> {
8 fn parse(expr: &str) -> Result<Tree<Op<f32>>, String> {
9 parse(expr).map(|node| Tree::new(node))
10 }
11}
12
13fn parse(expr: &str) -> Result<TreeNode<Op<f32>>, String> {
14 let tokens = tokenize(expr);
15 let mut pos = 0;
16 parse_expression(&tokens, &mut pos)
17}
18
/// Lexical token produced by [`tokenize`] and consumed by the parser functions.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Numeric literal such as `3` or `1.5`.
    Number(f32),
    /// Variable name together with its resolved index (`None` only while tokenizing).
    Identifier(String, Option<usize>),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `^`
    Power,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// End-of-input marker appended after the last real token.
    EOF,
}

/// Splits `expression` into tokens and resolves every identifier to a variable index.
///
/// Numbers are runs of ASCII digits and `.`; identifiers start with an ASCII letter and continue
/// with alphanumeric characters or `_`. The returned vector always ends with [`Token::EOF`].
///
/// Variable indices are the positions of the names in the sorted list of *distinct* identifiers,
/// so `b + a + b` yields `a -> 0` and `b -> 1`.
///
/// # Panics
/// Panics on a character that cannot start any token, and on a number literal that is not a
/// valid `f32` (for example `1.2.3` or a lone `.`).
fn tokenize(expression: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut vars = Vec::new();
    let mut chars = expression.chars().peekable();

    while let Some(&ch) = chars.peek() {
        match ch {
            // `\r` is accepted so that input with Windows line endings tokenizes as well.
            ' ' | '\t' | '\n' | '\r' => {
                chars.next();
            }
            '0'..='9' | '.' => {
                let mut num = String::new();
                while let Some(&digit) = chars.peek() {
                    if !(digit.is_ascii_digit() || digit == '.') {
                        break;
                    }
                    num.push(digit);
                    chars.next();
                }
                let value = num
                    .parse::<f32>()
                    .unwrap_or_else(|_| panic!("Invalid number literal: {}", num));
                tokens.push(Token::Number(value));
            }
            'a'..='z' | 'A'..='Z' => {
                let mut ident = String::new();
                while let Some(&letter) = chars.peek() {
                    if !(letter.is_alphanumeric() || letter == '_') {
                        break;
                    }
                    ident.push(letter);
                    chars.next();
                }
                // The index is filled in once every variable name is known.
                tokens.push(Token::Identifier(ident.clone(), None));
                vars.push(ident);
            }
            '+' => {
                chars.next();
                tokens.push(Token::Plus);
            }
            '-' => {
                chars.next();
                tokens.push(Token::Minus);
            }
            '*' => {
                chars.next();
                tokens.push(Token::Multiply);
            }
            '/' => {
                chars.next();
                tokens.push(Token::Divide);
            }
            '^' => {
                chars.next();
                tokens.push(Token::Power);
            }
            '(' => {
                chars.next();
                tokens.push(Token::LParen);
            }
            ')' => {
                chars.next();
                tokens.push(Token::RParen);
            }
            _ => panic!("Unexpected character: {}", ch),
        }
    }

    tokens.push(Token::EOF);

    // Sort before deduplicating: `dedup` only removes *adjacent* duplicates, so running it on the
    // unsorted list leaves repeats behind (`x * y + x`) and shifts the indices of later names.
    vars.sort_unstable();
    vars.dedup();

    for token in &mut tokens {
        if let Token::Identifier(name, index) = token {
            // Every identifier was pushed into `vars`, so the search always succeeds.
            *index = vars.binary_search(&*name).ok();
        }
    }

    tokens
}
117
118fn parse_expression(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
119 let mut node = parse_term(tokens, pos)?;
120
121 while let Some(token) = tokens.get(*pos) {
122 match token {
123 Token::Plus | Token::Minus => {
124 let op = token.clone();
125 *pos += 1;
126 let right = parse_term(tokens, pos)?;
127 node = TreeNode::new(match op {
128 Token::Plus => Op::add(),
129 Token::Minus => Op::sub(),
130 _ => unreachable!(),
131 })
132 .attach(node)
133 .attach(right);
134 }
135 _ => break,
136 }
137 }
138
139 Ok(node)
140}
141
142fn parse_term(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
143 let mut node = parse_power(tokens, pos)?;
144
145 while let Some(token) = tokens.get(*pos) {
146 match token {
147 Token::Multiply | Token::Divide => {
148 let op = token.clone();
149 *pos += 1;
150 let right = parse_power(tokens, pos)?;
151 node = TreeNode::new(match op {
152 Token::Multiply => Op::mul(),
153 Token::Divide => Op::div(),
154 _ => unreachable!(),
155 })
156 .attach(node)
157 .attach(right);
158 }
159 _ => break,
160 }
161 }
162
163 Ok(node)
164}
165
166fn parse_power(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
167 let mut node = parse_factor(tokens, pos)?;
168
169 if let Some(Token::Power) = tokens.get(*pos) {
170 *pos += 1;
171 let right = parse_power(tokens, pos)?;
172 node = TreeNode::new(Op::pow()).attach(node).attach(right);
173 }
174
175 Ok(node)
176}
177
178fn parse_factor(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
179 match tokens.get(*pos) {
180 Some(Token::Minus) => {
181 *pos += 1;
182 Ok(TreeNode::new(Op::neg()).attach(parse_factor(tokens, pos)?))
183 }
184 Some(Token::Plus) => {
185 *pos += 1;
186 parse_factor(tokens, pos)
187 }
188 Some(Token::Number(n)) => {
189 *pos += 1;
190 Ok(TreeNode::new(Op::constant(*n)))
191 }
192 Some(Token::Identifier(_, var)) => {
193 *pos += 1;
194 Ok(TreeNode::new(Op::var(var.unwrap())))
195 }
196 Some(Token::LParen) => {
197 *pos += 1;
198 let node = parse_expression(tokens, pos)?;
199 if let Some(Token::RParen) = tokens.get(*pos) {
200 *pos += 1;
201 Ok(node)
202 } else {
203 Err("Expected ')'".to_string())
204 }
205 }
206 token => Err(format!("Unexpected token: {:?}", token)),
207 }
208}
209
#[cfg(test)]
mod test {
    use crate::{Eval, Tree, ops::expr::Expression};

    // Constants only: 1 + 2 * 12^5.
    #[test]
    fn test_tokenize() {
        let tree = Tree::parse("1 + 2 * (3 * 4)^5")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        assert_eq!(tree.eval(&[]), 497665.0);
    }

    // Same shape as above with every constant replaced by a variable.
    #[test]
    fn test_tokenize_with_vars() {
        let tree = Tree::parse("a + b * (c * d)^e")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        assert_eq!(tree.eval(&[1.0, 2.0, 3.0, 4.0, 5.0]), 497665.0);
    }

    // Mixes constants, one variable and a binary minus.
    #[test]
    fn test_tokenize_with_vars_and_negation() {
        let tree = Tree::parse("5 - x * (34 * 3)^2")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        assert_eq!(tree.eval(&[3.0]), -31207.0);
    }

    // Compares the parsed polynomial against a hand-written closure on a range of inputs.
    #[test]
    fn test_tokenize_with_vars_and_negation_and_parens() {
        let expected = |x: f32| 4.0 * x.powf(3.0) - 3.0 * x.powf(2.0) + x;
        let tree = Tree::parse("4 * x^3 - 3 * x^2 + x")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        // The input starts at -1.0 and is advanced by 0.1 before each of the 20 evaluations.
        let mut input = -1.0;
        for _ in -10..10 {
            input += 0.1;
            let output = tree.eval(&[input]);
            assert!((output - expected(input)).abs() < 0.0001);
        }
    }
}