formulac/
lib.rs

1//! # formulac
2//!
3//! `formulac` is a Rust library for parsing and evaluating mathematical
4//! expressions with support for **complex numbers** and **extensible user-defined functions**.
5//!
6//! ## Overview
7//! - Parse and evaluate expressions containing real and imaginary numbers.
8//! - Use built-in operators, constants, and mathematical functions.
9//! - Register your own variables and functions.
10//! - Compile expressions into callable closures for repeated evaluation without re-parsing.
11//!
//! Internally, expressions are first tokenized into lexemes,
//! then parsed into an abstract syntax tree (AST) using the shunting-yard algorithm,
//! simplified, and finally compiled into Reverse Polish Notation (RPN) stack operations
//! for fast repeated execution.
16//!
17//! ## Feature Highlights
18//! - **Complex number support** using [`num_complex::Complex<f64>`]
19//! - **User-defined functions and constants** via [`UserDefinedTable`]
20//! - **Variables and arguments** managed by [`Variables`]
21//! - **Operator precedence** and parentheses handling
22//! - **Efficient compiled closures** avoiding repeated parsing
23//!
24//! ## Example
25//! ```rust
26//! use num_complex::Complex;
27//! use formulac::{compile, Variables, UserDefinedTable};
28//!
29//! let mut vars = Variables::new();
30//! vars.insert(&[("a", Complex::new(3.0, 2.0))]);
31//!
32//! let users = UserDefinedTable::new();
33//! let expr = compile("sin(z) + a * cos(z)", &["z"], &vars, &users)
34//!     .expect("Failed to compile formula");
35//!
36//! let result = expr(&[Complex::new(1.0, 2.0)]);
37//! println!("Result = {}", result);
38//! ```
39//!
40//! ## Example: Retrieving All Names
41//! ```rust
42//! use formulac::parser::{constant, UnaryOperatorKind, BinaryOperatorKind, FunctionKind};
43//!
44//! // Constants
45//! let constant_names: Vec<&'static str> = constant::names();
46//! println!("Constants: {:?}", constant_names);
47//!
48//! // Unary operators
49//! let unary_names: Vec<&'static str> = UnaryOperatorKind::names();
50//! println!("Unary Operators: {:?}", unary_names);
51//!
52//! // Binary operators
53//! let binary_names: Vec<&'static str> = BinaryOperatorKind::names();
54//! println!("Binary Operators: {:?}", binary_names);
55//!
56//! // Functions
57//! let function_names: Vec<&'static str> = FunctionKind::names();
58//! println!("Functions: {:?}", function_names);
59//! ```
60//!
61//! ## When to Use
62//! Use `formulac` when you need:
63//! - Fast repeated evaluation of mathematical formulas
64//! - Complex number support in expressions
65//! - Runtime extensibility via custom functions or constants
66//!
67//! ## License
68//! Licensed under either **MIT** or **Apache-2.0** at your option.
69
70mod lexer;
71pub mod parser;
72pub mod variable;
73
74use num_complex::Complex;
75use crate::{parser::Token};
76use crate::{variable::FunctionCall};
77
78pub type UserDefinedFunction = variable::UserDefinedFunction;
79pub type UserDefinedTable = variable::UserDefinedTable;
80pub type Variables = variable::Variables;
81
82/// Compiles a mathematical expression into an executable closure.
83///
84/// This function parses a formula string into an abstract syntax tree (AST),
85/// simplifies it, and then compiles it into a list of stack operations
86/// (in Reverse Polish Notation). The result is returned as a closure that
87/// can be called multiple times with different argument values without
88/// re-parsing the formula.
89///
90/// # Parameters
91/// - `formula`: A string slice containing the mathematical expression to compile.
92/// - `arg_names`: A slice of argument names (`&str`) that the formula depends on.
93///   The closure returned will expect argument values in the same order.
94/// - `vars`: A [`Variables`] table mapping variable names
95///   to constant values available in the formula.
96/// - `users`: A [`UserDefinedTable`] containing
97///   any user-defined functions or constants available in the formula.
98///
99/// # Returns
100/// On success, returns a closure of type:
101///
102/// ```rust,ignore
103/// Fn(&[Complex<f64>]) -> Complex<f64>
104/// ```
105///
106/// - The closure takes a slice of complex argument values corresponding to `arg_names`.
107/// - Returns `Complex<f64>` if evaluation succeeds.
108/// - Panics if the number of arguments provided does not match `arg_names.len()`.
109///   So check it with debug build.
110///
111/// On failure, returns an error string describing the parsing or compilation error.
112///
113/// # Example
114/// ```rust
115/// use num_complex::Complex;
116/// use formulac::{compile, Variables, UserDefinedTable};
117///
118/// let mut vars = Variables::new();
119/// vars.insert(&[("a", Complex::new(3.0, 2.0))]);
120///
121/// let users = UserDefinedTable::new();
122/// let expr = compile("sin(z) + a * cos(z)", &["z"], &vars, &users)
123///     .expect("Failed to compile formula");
124///
125/// let result = expr(&[Complex::new(1.0, 2.0)]);
126/// println!("Result = {}", result);
127/// ```
128///
129/// # Notes
130/// - The formula string must be a valid expression using supported operators,
131///   variables, and functions.
132/// - Argument names are resolved in the order provided by `arg_names`.
133/// - This function does not evaluate immediately; instead, it produces
134///   a reusable compiled closure for efficient repeated evaluation.
135pub fn compile(
136    formula: &str,
137    arg_names: &[&str],
138    vars: &Variables,
139    users: &UserDefinedTable
140) -> Result<impl Fn(&[Complex<f64>]) -> Complex<f64>, String>
141{
142    let lexemes = lexer::from(formula);
143    let tokens = parser::AstNode::from(&lexemes, arg_names, vars, users)?
144        .simplify().compile();
145
146    let expected_arity = arg_names.len();
147    let func = move |arg_values: &[Complex<f64>]| {
148        // check arity only debug build
149        debug_assert_eq!(arg_values.len(), expected_arity);
150
151        let mut stack: Vec<Complex<f64>> = Vec::new();
152        for token in tokens.iter() {
153            match token {
154                Token::Number(val) => stack.push(*val),
155                Token::Argument(idx) => stack.push(arg_values[*idx]),
156                Token::UnaryOperator(oper) => {
157                    let expr = stack.pop().unwrap();
158                    stack.push(oper.apply(expr));
159                },
160                Token::BinaryOperator(oper) => {
161                    let r = stack.pop().unwrap();
162                    let l = stack.pop().unwrap();
163                    stack.push(oper.apply(l, r));
164                },
165                Token::Function(func) => {
166                    let n = func.arity();
167                    let mut args: Vec<Complex<f64>> = Vec::with_capacity(n);
168                    args.resize(n, Complex::new(0.0, 0.0));
169
170                    for i in (0..n).rev() {
171                        args[i] = stack.pop().unwrap();
172                    }
173                    stack.push(func.apply(&args));
174                },
175                Token::UserFunction(func) => {
176                    let n = func.arity();
177                    let mut args: Vec<Complex<f64>> = Vec::with_capacity(n);
178                    args.resize(n, Complex::new(0.0, 0.0));
179
180                    for i in (0..n).rev() {
181                        args[i] = stack.pop().unwrap();
182                    }
183                    stack.push(func.apply(&args));
184                },
185                _ => unreachable!("Invalid tokens found: use compiled tokens"),
186            }
187        }
188
189        stack.pop().unwrap()
190    };
191
192    Ok(func)
193}
194
#[cfg(test)]
mod compile_test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use num_complex::Complex;

    /// Absolute tolerance used for all floating-point comparisons.
    const EPS: f64 = 1.0e-12;

    /// Asserts that both components of `actual` are within `EPS` of `expected`.
    ///
    /// `#[track_caller]` makes a failure report the line of the calling test.
    #[track_caller]
    fn assert_complex_eq(actual: Complex<f64>, expected: Complex<f64>) {
        assert_abs_diff_eq!(actual.re, expected.re, epsilon = EPS);
        assert_abs_diff_eq!(actual.im, expected.im, epsilon = EPS);
    }

    #[test]
    fn test_constant_number() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("42", &[], &vars, &users).unwrap();
        assert_eq!(f(&[]), Complex::new(42.0, 0.0));
    }

    #[test]
    fn test_constant_str() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("PI", &[], &vars, &users).unwrap();
        assert_eq!(f(&[]), Complex::from(std::f64::consts::PI));
    }

    #[test]
    fn test_argument() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x", &["x"], &vars, &users).unwrap();
        assert_eq!(f(&[Complex::new(3.0, 0.0)]), Complex::new(3.0, 0.0));
    }

    #[test]
    fn test_addition() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + y", &["x", "y"], &vars, &users).unwrap();
        let x = Complex::new(2.0, 1.0);
        let y = Complex::new(3.0, 5.0);
        assert_complex_eq(f(&[x, y]), x + y);
    }

    #[test]
    fn test_nested_expression() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("sin(x + 1)", &["x"], &vars, &users).unwrap();
        let result = f(&[Complex::new(0.0, 1.0)]);
        assert_complex_eq(result, Complex::new(1.0, 1.0).sin());
    }

    #[test]
    fn test_binary_operator_precedence() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("2 + 3 * 4", &[], &vars, &users).unwrap();
        assert_complex_eq(f(&[]), Complex::from(2.0 + 3.0 * 4.0));
    }

    #[test]
    fn test_function_with_two_args() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("pow(a, b)", &["a", "b"], &vars, &users).unwrap();
        let a = Complex::new(2.0, 1.0);
        let b = Complex::new(-2.0, 3.0);
        assert_complex_eq(f(&[a, b]), a.powc(b));
    }

    #[test]
    fn test_differentiate_without_order() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("diff(x^2, x)", &["x"], &vars, &users).unwrap();
        let x = Complex::new(2.0, 1.0);
        // d/dx x^2 = 2x
        assert_complex_eq(f(&[x]), 2.0 * x);
    }

    #[test]
    fn test_differentiate_with_order() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("diff(x^3, x, 2)", &["x"], &vars, &users).unwrap();
        let x = Complex::new(2.0, 1.0);
        // d²/dx² x^3 = 6x
        assert_complex_eq(f(&[x]), 6.0 * x);
    }

    #[test]
    fn test_differentiate_with_userdefinedfunction() {
        let mut users = UserDefinedTable::new();

        // Define f(x) = x^2
        let func = UserDefinedFunction::new(
            "f",
            |args: &[Complex<f64>]| args[0] * args[0],
            1,
        ).with_derivative(
            // derivative f'(x) = 2x
            vec![|args: &[Complex<f64>]| Complex::new(2.0, 0.0) * args[0]],
        );
        users.register("f", func);

        let vars = Variables::new();
        let expr = compile("diff(f(x), x)", &["x"], &vars, &users).unwrap();

        // f'(3) = 6
        assert_complex_eq(expr(&[Complex::new(3.0, 0.0)]), Complex::new(6.0, 0.0));
    }

    #[test]
    fn test_differentiate_with_partial_derivative() {
        let mut users = UserDefinedTable::new();

        // Define g(x, y) = x^2 * y + y^3
        let func = UserDefinedFunction::new(
            "g",
            |args: &[Complex<f64>]| args[0]*args[0]*args[1] + args[1]*args[1]*args[1],
            2,
        ).with_derivative(vec![
            // partial derivative w.r.t x: ∂g/∂x = 2*x*y
            |args: &[Complex<f64>]| Complex::new(2.0, 0.0) * args[0] * args[1],
            // partial derivative w.r.t y: ∂g/∂y = x^2 + 3*y^2
            |args: &[Complex<f64>]| args[0]*args[0] + Complex::new(3.0, 0.0)*args[1]*args[1],
        ]);
        users.register("g", func);

        let vars = Variables::new();

        let x = Complex::new(2.0, 0.0);
        let y = Complex::new(3.0, 0.0);

        let expr_dx = compile("diff(g(x, y), x)", &["x", "y"], &vars, &users).unwrap();
        assert_complex_eq(expr_dx(&[x, y]), 2.0 * x * y);

        let expr_dy = compile("diff(g(x, y), y)", &["x", "y"], &vars, &users).unwrap();
        assert_complex_eq(expr_dy(&[x, y]), x * x + 3.0 * y * y);
    }

    #[test]
    #[should_panic]
    fn test_too_few_args_length() {
        // `x` is read at index 0, so an empty slice panics in every build
        // profile: via `debug_assert_eq!` in debug, via out-of-bounds
        // indexing in release.
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + 1", &["x"], &vars, &users).unwrap();
        f(&[]);
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_too_much_args_length() {
        // Debug builds reject surplus arguments through `debug_assert_eq!`.
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + 1", &["x"], &vars, &users).unwrap();
        f(&[Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)]);
    }

    #[cfg(not(debug_assertions))]
    #[test]
    fn test_too_much_args_length() {
        // Release builds silently ignore surplus arguments.
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let f = compile("x + 1", &["x"], &vars, &users).unwrap();
        let result = f(&[Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)]);
        assert_complex_eq(result, Complex::new(2.0, 0.0));
    }

    #[test]
    fn test_variables() {
        let a = Complex::new(2.0, 1.0);
        let b = Complex::new(-4.0, 2.0);
        let x = Complex::new(1.0, 0.0);
        let mut vars = Variables::new();
        let users = UserDefinedTable::new();
        vars.insert(&[("a", a), ("b", b)]);

        let f = compile("a * x + b", &["x"], &vars, &users).unwrap();
        assert_complex_eq(f(&[x]), a * x + b);
    }
}

#[cfg(test)]
mod issue_test {
    use super::*;
    use approx::assert_abs_diff_eq;
    use num_complex::Complex;

    /// Regression test for issue #1: parentheses were not affecting
    /// function-call precedence as the example code implied, i.e.
    /// `f(x) + y` was being parsed as `f(x + y)`.
    ///
    /// Reported in v0.5.0 and resolved in v0.5.1.
    #[test]
    fn test_issue_1() {
        let vars = Variables::new();
        let users = UserDefinedTable::new();
        let z = Complex::new(1.0, 3.0);

        // Each formula paired with the value it must evaluate to at `z`.
        let cases: [(&str, Complex<f64>); 3] = [
            ("sin(z) + z", z.sin() + z),
            ("sin(z + z)", (z + z).sin()),
            ("(sin(z)) + z", (z.sin()) + z),
        ];

        for &(formula, expected) in cases.iter() {
            let compiled = compile(formula, &["z"], &vars, &users)
                .expect("failed to compile formula");
            let actual = compiled(&[z]);
            assert_abs_diff_eq!(actual.re, expected.re, epsilon = 1.0e-12);
            assert_abs_diff_eq!(actual.im, expected.im, epsilon = 1.0e-12);
        }
    }
}