1mod lexer;
71pub mod parser;
72pub mod variable;
73
74use num_complex::Complex;
75use crate::{parser::Token};
76use crate::{variable::FunctionCall};
77
78pub type UserDefinedFunction = variable::UserDefinedFunction;
79pub type UserDefinedTable = variable::UserDefinedTable;
80pub type Variables = variable::Variables;
81
82pub fn compile(
136 formula: &str,
137 arg_names: &[&str],
138 vars: &Variables,
139 users: &UserDefinedTable
140) -> Result<impl Fn(&[Complex<f64>]) -> Complex<f64>, String>
141{
142 let lexemes = lexer::from(formula);
143 let tokens = parser::AstNode::from(&lexemes, arg_names, vars, users)?
144 .simplify().compile();
145
146 let expected_arity = arg_names.len();
147 let func = move |arg_values: &[Complex<f64>]| {
148 debug_assert_eq!(arg_values.len(), expected_arity);
150
151 let mut stack: Vec<Complex<f64>> = Vec::new();
152 for token in tokens.iter() {
153 match token {
154 Token::Number(val) => stack.push(*val),
155 Token::Argument(idx) => stack.push(arg_values[*idx]),
156 Token::UnaryOperator(oper) => {
157 let expr = stack.pop().unwrap();
158 stack.push(oper.apply(expr));
159 },
160 Token::BinaryOperator(oper) => {
161 let r = stack.pop().unwrap();
162 let l = stack.pop().unwrap();
163 stack.push(oper.apply(l, r));
164 },
165 Token::Function(func) => {
166 let n = func.arity();
167 let mut args: Vec<Complex<f64>> = Vec::with_capacity(n);
168 args.resize(n, Complex::new(0.0, 0.0));
169
170 for i in (0..n).rev() {
171 args[i] = stack.pop().unwrap();
172 }
173 stack.push(func.apply(&args));
174 },
175 Token::UserFunction(func) => {
176 let n = func.arity();
177 let mut args: Vec<Complex<f64>> = Vec::with_capacity(n);
178 args.resize(n, Complex::new(0.0, 0.0));
179
180 for i in (0..n).rev() {
181 args[i] = stack.pop().unwrap();
182 }
183 stack.push(func.apply(&args));
184 },
185 _ => unreachable!("Invalid tokens found: use compiled tokens"),
186 }
187 }
188
189 stack.pop().unwrap()
190 };
191
192 Ok(func)
193}
194
/// Behavioural tests for [`compile`]: literals, arguments, operators,
/// built-in and user-defined functions, differentiation, and arity handling.
#[cfg(test)]
mod compile_test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use num_complex::Complex;

    /// A numeric literal evaluates to itself.
    #[test]
    fn test_constant_number() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("42", &[], &vars, &users).unwrap();
        let result = f(&[]);
        assert_eq!(result, Complex::new(42.0, 0.0));
    }

    /// The name `PI` evaluates to `std::f64::consts::PI` without being registered.
    #[test]
    fn test_constant_str() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("PI", &[], &vars, &users).unwrap();
        let result = f(&[]);
        assert_eq!(result, Complex::from(std::f64::consts::PI));
    }

    /// A bare argument name evaluates to the value passed for it.
    #[test]
    fn test_argument() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x", &["x"], &vars, &users).unwrap();
        let result = f(&[Complex::new(3.0, 0.0)]);
        assert_eq!(result, Complex::new(3.0, 0.0));
    }

    /// Two arguments are bound positionally and added.
    #[test]
    fn test_addition() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + y", &["x", "y"], &vars, &users).unwrap();
        let x = Complex::new(2.0, 1.0);
        let y = Complex::new(3.0, 5.0);
        let result = f(&[x, y]);
        assert_abs_diff_eq!(result.re, (x + y).re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, (x + y).im, epsilon = 1.0e-12);
    }

    /// A built-in function applied to a compound expression.
    #[test]
    fn test_nested_expression() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("sin(x + 1)", &["x"], &vars, &users).unwrap();
        let result = f(&[Complex::new(0.0, 1.0)]);
        let expected = Complex::new(1.0, 1.0).sin();
        assert_abs_diff_eq!(result.re, expected.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, expected.im, epsilon = 1.0e-12);
    }

    /// Multiplication binds tighter than addition.
    #[test]
    fn test_binary_operator_precedence() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("2 + 3 * 4", &[], &vars, &users).unwrap();
        let result = f(&[]);
        let expected = Complex::from(2.0 + 3.0 * 4.0);
        assert_abs_diff_eq!(result.re, expected.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, expected.im, epsilon = 1.0e-12);
    }

    /// A two-argument built-in receives its operands in source order.
    #[test]
    fn test_function_with_two_args() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("pow(a, b)", &["a", "b"], &vars, &users).unwrap();
        let a = Complex::new(2.0, 1.0);
        let b = Complex::new(-2.0, 3.0);
        let result = f(&[a, b]);
        let expected = a.powc(b);
        assert_abs_diff_eq!(result.re, expected.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, expected.im, epsilon = 1.0e-12);
    }

    /// `diff(expr, var)` without an order yields the first derivative.
    #[test]
    fn test_differentiate_without_order() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("diff(x^2, x)", &["x"], &vars, &users).unwrap();
        let x = Complex::new(2.0, 1.0);
        let result = f(&[x]);
        let expected = 2.0 * x;
        assert_abs_diff_eq!(result.re, expected.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, expected.im, epsilon = 1.0e-12);
    }

    /// `diff(expr, var, n)` yields the n-th derivative.
    #[test]
    fn test_differentiate_with_order() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("diff(x^3, x, 2)", &["x"], &vars, &users).unwrap();
        let x = Complex::new(2.0, 1.0);
        let result = f(&[x]);
        let expected = 6.0 * x;
        assert_abs_diff_eq!(result.re, expected.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, expected.im, epsilon = 1.0e-12);
    }

    /// Differentiating a user-defined function uses its registered derivative.
    #[test]
    fn test_differentiate_with_userdefinedfunction() {
        let mut users = UserDefinedTable::new();

        let func = UserDefinedFunction::new(
            "f",
            |args: &[Complex<f64>]| args[0] * args[0],
            1,
        ).with_derivative(
            vec![|args: &[Complex<f64>]| Complex::new(2.0, 0.0) * args[0]],
        );
        users.register("f", func);

        let vars = Variables::new();
        let expr = compile("diff(f(x), x)", &["x"], &vars, &users).unwrap();

        let result = expr(&[Complex::new(3.0, 0.0)]);
        assert_abs_diff_eq!(result.re, 6.0, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, 0.0, epsilon = 1.0e-12);
    }

    /// For a two-argument user function, the derivative matching the
    /// differentiation variable is the one that gets evaluated.
    #[test]
    fn test_differentiate_with_partial_derivative() {
        let mut users = UserDefinedTable::new();

        let func = UserDefinedFunction::new(
            "g",
            |args: &[Complex<f64>]| args[0]*args[0]*args[1] + args[1]*args[1]*args[1],
            2,
        ).with_derivative(vec![
            |args: &[Complex<f64>]| Complex::new(2.0, 0.0) * args[0] * args[1],
            |args: &[Complex<f64>]| args[0]*args[0] + Complex::new(3.0, 0.0)*args[1]*args[1],
        ]);
        users.register("g", func);

        let vars = Variables::new();

        let x = Complex::new(2.0, 0.0);
        let y = Complex::new(3.0, 0.0);

        let expr_dx = compile("diff(g(x, y), x)", &["x", "y"], &vars, &users).unwrap();
        let result_dx = expr_dx(&[x, y]);
        let expect_dx = 2.0 * x * y;
        assert_abs_diff_eq!(result_dx.re, expect_dx.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result_dx.im, expect_dx.im, epsilon = 1.0e-12);

        let expr_dy = compile("diff(g(x, y), y)", &["x", "y"], &vars, &users).unwrap();
        let result_dy = expr_dy(&[x, y]);
        let expect_dy = x * x + 3.0 * y * y;
        assert_abs_diff_eq!(result_dy.re, expect_dy.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result_dy.im, expect_dy.im, epsilon = 1.0e-12);
    }

    /// Passing fewer values than declared arguments panics.
    ///
    /// Only the too-few case is exercised here: the first panicking call ends
    /// a `#[should_panic]` test, so anything after it would never run.
    #[test]
    #[should_panic]
    fn test_too_less_args_length() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + 1", &["x"], &vars, &users).unwrap();
        f(&[]);
    }

    /// With debug assertions enabled, surplus values trip the arity check.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_too_much_args_length() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + 1", &["x"], &vars, &users).unwrap();
        f(&[Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)]);
    }

    /// Without debug assertions, surplus values are ignored and evaluation succeeds.
    #[cfg(not(debug_assertions))]
    #[test]
    fn test_too_much_args_length() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + 1", &["x"], &vars, &users).unwrap();
        let result = f(&[Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)]);
        assert_abs_diff_eq!(result.re, 2.0, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, 0.0, epsilon = 1.0e-12);
    }

    /// Names inserted into `Variables` are usable alongside positional arguments.
    #[test]
    fn test_variables() {
        let a = Complex::new(2.0, 1.0);
        let b = Complex::new(-4.0, 2.0);
        let x = Complex::new(1.0, 0.0);
        let mut vars = Variables::new();
        let users = UserDefinedTable::new();
        vars.insert(&[("a", a), ("b", b),]);

        let f = compile("a * x + b", &["x"], &vars, &users).unwrap();
        let result = f(&[x]);
        let expected = a * x + b;
        assert_abs_diff_eq!(result.re, expected.re, epsilon = 1.0e-12);
        assert_abs_diff_eq!(result.im, expected.im, epsilon = 1.0e-12);
    }
}
406
/// Regression tests named after reported issues.
#[cfg(test)]
mod issue_test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use num_complex::Complex;

    /// Checks three formulas that combine `sin` with an addition: the sum
    /// outside the call, the sum inside the call, and the call wrapped in
    /// parentheses. Each must match the directly computed value.
    #[test]
    fn test_issue_1() {
        let users = UserDefinedTable::new();
        let vars = Variables::new();
        let z = Complex::new(1.0, 3.0);

        // (formula text, value computed directly with `Complex` arithmetic)
        let cases = [
            ("sin(z) + z", z.sin() + z),
            ("sin(z + z)", (z + z).sin()),
            ("(sin(z)) + z", (z.sin()) + z),
        ];

        for &(formula, expected) in cases.iter() {
            let evaluate = compile(formula, &["z"], &vars, &users)
                .expect("failed to compile formula");
            let actual = evaluate(&[z]);
            assert_abs_diff_eq!(actual.re, expected.re, epsilon = 1.0e-12);
            assert_abs_diff_eq!(actual.im, expected.im, epsilon = 1.0e-12);
        }
    }
}