expr_solver/
symbol.rs

1use rust_decimal::Decimal;
2use rust_decimal::prelude::*;
3use std::borrow::Cow;
4use std::panic;
5use thiserror::Error;
6
7/// Errors that can occur during function evaluation.
8#[derive(Error, Debug, Clone)]
9pub enum FuncError {
10    #[error("Conversion error: failed to convert Decimal to f64")]
11    DecimalToF64Conversion,
12    #[error("Conversion error: failed to convert f64 result back to Decimal")]
13    F64ToDecimalConversion,
14    #[error("Square root of negative number: {value}")]
15    NegativeSqrt { value: Decimal },
16    #[error("Domain error in function '{function}': invalid input {input}")]
17    DomainError { function: String, input: Decimal },
18    #[error("Math error: {message}")]
19    MathError { message: String },
20}
21
22/// Helper function for single-argument f64 calculations
23fn f64_calc_1<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
24where
25    F: Fn(f64) -> f64,
26{
27    let arg = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
28    let result = func(arg);
29    Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
30}
31
32/// Helper function for two-argument f64 calculations
33fn f64_calc_2<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
34where
35    F: Fn(f64, f64) -> f64,
36{
37    let arg1 = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
38    let arg2 = args[1].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
39    let result = func(arg1, arg2);
40    Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
41}
42
/// Errors that can occur during symbol table operations.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum SymbolError {
    /// A symbol with this name already exists in the table.
    ///
    /// Names are compared ASCII case-insensitively, so adding `PI` to a table that already
    /// contains `pi` also produces this error. The payload is the name that was rejected.
    #[error("Duplicate symbol definition: '{0}'")]
    DuplicateSymbol(String),
}
50
/// A symbol representing either a constant or function.
///
/// Symbols are stored in a [`SymTable`] and referenced during evaluation.
#[derive(Debug, Clone)]
pub enum Symbol {
    /// Named constant (e.g., `pi`).
    Const {
        /// Name the constant is looked up by.
        name: Cow<'static, str>,
        /// The constant's value.
        value: Decimal,
        /// Optional human-readable description.
        description: Option<Cow<'static, str>>,
    },
    /// Function with specified arity and callback.
    Func {
        /// Name the function is looked up by.
        name: Cow<'static, str>,
        /// Minimum number of arguments
        /// (presumably the exact count when `variadic` is false; confirm with the evaluator).
        args: usize,
        /// Whether the function accepts additional arguments
        variadic: bool,
        /// Implementation; receives the call's arguments as a slice and returns the result
        /// or a [`FuncError`].
        callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
        /// Optional human-readable description.
        description: Option<Cow<'static, str>>,
    },
}
73
74impl Symbol {
75    /// Returns the name of the symbol.
76    pub fn name(&self) -> &str {
77        match self {
78            Symbol::Const { name, .. } => name,
79            Symbol::Func { name, .. } => name,
80        }
81    }
82
83    /// Returns the description of the symbol, if available.
84    pub fn description(&self) -> Option<&str> {
85        match self {
86            Symbol::Const { description, .. } => description.as_deref(),
87            Symbol::Func { description, .. } => description.as_deref(),
88        }
89    }
90}
91
/// Symbol table containing constants and functions.
///
/// The table stores mathematical constants like `pi` and functions like `sin`.
/// Symbol lookups are ASCII case-insensitive (`PI` finds `pi`), and names must be unique
/// under that comparison. Symbols are kept in insertion order and found by a linear scan.
///
/// # Examples
///
/// ```
/// use expr_solver::SymTable;
/// use rust_decimal_macros::dec;
///
/// let mut table = SymTable::stdlib();
/// table.add_const("x", dec!(42)).unwrap();
/// ```
#[derive(Debug, Default, Clone)]
pub struct SymTable {
    /// All registered symbols, in insertion order.
    symbols: Vec<Symbol>,
}
110
111impl SymTable {
112    /// Creates an empty symbol table.
113    pub fn new() -> Self {
114        Self::default()
115    }
116
117    /// Creates a symbol table with the standard library.
118    ///
119    /// ## Constants
120    /// - `pi` - π (3.14159...)
121    /// - `e` - Euler's number (2.71828...)
122    /// - `tau` - 2π (6.28318...)
123    /// - `ln2` - Natural logarithm of 2
124    /// - `ln10` - Natural logarithm of 10
125    /// - `sqrt2` - Square root of 2
126    ///
127    /// ## Fixed arity functions
128    /// - `sin(x)` - Sine
129    /// - `cos(x)` - Cosine
130    /// - `tan(x)` - Tangent
131    /// - `asin(x)` - Arcsine
132    /// - `acos(x)` - Arccosine
133    /// - `atan(x)` - Arctangent
134    /// - `atan2(y, x)` - Two-argument arctangent
135    /// - `sinh(x)` - Hyperbolic sine
136    /// - `cosh(x)` - Hyperbolic cosine
137    /// - `tanh(x)` - Hyperbolic tangent
138    /// - `sqrt(x)` - Square root
139    /// - `cbrt(x)` - Cube root
140    /// - `pow(x, y)` - x raised to power y
141    /// - `log(x)` - Natural logarithm
142    /// - `log2(x)` - Base-2 logarithm
143    /// - `log10(x)` - Base-10 logarithm
144    /// - `exp(x)` - e raised to power x
145    /// - `exp2(x)` - 2 raised to power x
146    /// - `abs(x)` - Absolute value
147    /// - `sign(x)` - Sign function (-1, 0, or 1)
148    /// - `floor(x)` - Floor function
149    /// - `ceil(x)` - Ceiling function
150    /// - `round(x)` - Round to nearest integer
151    /// - `trunc(x)` - Truncate to integer
152    /// - `fract(x)` - Fractional part
153    /// - `mod(x, y)` - Remainder of x/y
154    /// - `hypot(x, y)` - Euclidean distance sqrt(x²+y²)
155    /// - `clamp(x, min, max)` - Constrain value between bounds
156    ///
157    /// ## Variadic functions
158    /// - `min(x, ...)` - Minimum value
159    /// - `max(x, ...)` - Maximum value
160    /// - `sum(x, ...)` - Sum of values
161    /// - `avg(x, ...)` - Average of values
162    pub fn stdlib() -> Self {
163        Self {
164            symbols: vec![
165                // Constants
166                Symbol::Const {
167                    name: "pi".into(),
168                    value: Decimal::PI,
169                    description: Some("π (3.14159...)".into()),
170                },
171                Symbol::Const {
172                    name: "e".into(),
173                    value: Decimal::E,
174                    description: Some("Euler's number (2.71828...)".into()),
175                },
176                Symbol::Const {
177                    name: "tau".into(),
178                    value: Decimal::TWO_PI,
179                    description: Some("2π (6.28318...)".into()),
180                },
181                Symbol::Const {
182                    name: "ln2".into(),
183                    value: Decimal::TWO.ln(),
184                    description: Some("Natural logarithm of 2".into()),
185                },
186                Symbol::Const {
187                    name: "ln10".into(),
188                    value: Decimal::TEN.log10(),
189                    description: Some("Natural logarithm of 10".into()),
190                },
191                Symbol::Const {
192                    name: "sqrt2".into(),
193                    value: Decimal::TWO.sqrt().unwrap(),
194                    description: Some("Square root of 2".into()),
195                },
196                // Trigonometric functions
197                Symbol::Func {
198                    name: "sin".into(),
199                    args: 1,
200                    variadic: false,
201                    callback: |args| Ok(args[0].sin()),
202                    description: Some("Sine".into()),
203                },
204                Symbol::Func {
205                    name: "cos".into(),
206                    args: 1,
207                    variadic: false,
208                    callback: |args| Ok(args[0].cos()),
209                    description: Some("Cosine".into()),
210                },
211                Symbol::Func {
212                    name: "tan".into(),
213                    args: 1,
214                    variadic: false,
215                    callback: |args| {
216                        let input = args[0];
217                        match panic::catch_unwind(panic::AssertUnwindSafe(|| input.tan())) {
218                            Ok(result) => Ok(result),
219                            Err(_) => Err(FuncError::DomainError {
220                                function: "tan".to_string(),
221                                input,
222                            }),
223                        }
224                    },
225                    description: Some("Tangent".into()),
226                },
227                Symbol::Func {
228                    name: "asin".into(),
229                    args: 1,
230                    variadic: false,
231                    callback: |args| f64_calc_1(args, |x| x.asin()),
232                    description: Some("Arcsine".into()),
233                },
234                Symbol::Func {
235                    name: "acos".into(),
236                    args: 1,
237                    variadic: false,
238                    callback: |args| f64_calc_1(args, |x| x.acos()),
239                    description: Some("Arccosine".into()),
240                },
241                Symbol::Func {
242                    name: "atan".into(),
243                    args: 1,
244                    variadic: false,
245                    callback: |args| f64_calc_1(args, |x| x.atan()),
246                    description: Some("Arctangent".into()),
247                },
248                Symbol::Func {
249                    name: "atan2".into(),
250                    args: 2,
251                    variadic: false,
252                    callback: |args| f64_calc_2(args, |y, x| y.atan2(x)),
253                    description: Some("Two-argument arctangent".into()),
254                },
255                Symbol::Func {
256                    name: "sinh".into(),
257                    args: 1,
258                    variadic: false,
259                    callback: |args| f64_calc_1(args, |x| x.sinh()),
260                    description: Some("Hyperbolic sine".into()),
261                },
262                Symbol::Func {
263                    name: "cosh".into(),
264                    args: 1,
265                    variadic: false,
266                    callback: |args| f64_calc_1(args, |x| x.cosh()),
267                    description: Some("Hyperbolic cosine".into()),
268                },
269                Symbol::Func {
270                    name: "tanh".into(),
271                    args: 1,
272                    variadic: false,
273                    callback: |args| f64_calc_1(args, |x| x.tanh()),
274                    description: Some("Hyperbolic tangent".into()),
275                },
276                // Power and root functions
277                Symbol::Func {
278                    name: "sqrt".into(),
279                    args: 1,
280                    variadic: false,
281                    callback: |args| {
282                        args[0]
283                            .sqrt()
284                            .ok_or_else(|| FuncError::NegativeSqrt { value: args[0] })
285                    },
286                    description: Some("Square root".into()),
287                },
288                Symbol::Func {
289                    name: "cbrt".into(),
290                    args: 1,
291                    variadic: false,
292                    callback: |args| f64_calc_1(args, |x| x.cbrt()),
293                    description: Some("Cube root".into()),
294                },
295                Symbol::Func {
296                    name: "pow".into(),
297                    args: 2,
298                    variadic: false,
299                    callback: |args| {
300                        let base = args[0];
301                        let exponent = args[1];
302                        match panic::catch_unwind(panic::AssertUnwindSafe(|| base.powd(exponent))) {
303                            Ok(result) => Ok(result),
304                            Err(_) => Err(FuncError::MathError {
305                                message: format!("Power operation failed: {}^{}", base, exponent),
306                            }),
307                        }
308                    },
309                    description: Some("x raised to power y".into()),
310                },
311                // Logarithmic and exponential functions
312                Symbol::Func {
313                    name: "log".into(),
314                    args: 1,
315                    variadic: false,
316                    callback: |args| {
317                        if args[0] <= Decimal::ZERO {
318                            Err(FuncError::DomainError {
319                                function: "log".to_string(),
320                                input: args[0],
321                            })
322                        } else {
323                            Ok(args[0].ln())
324                        }
325                    },
326                    description: Some("Natural logarithm".into()),
327                },
328                Symbol::Func {
329                    name: "log2".into(),
330                    args: 1,
331                    variadic: false,
332                    callback: |args| f64_calc_1(args, |x| x.log2()),
333                    description: Some("Base-2 logarithm".into()),
334                },
335                Symbol::Func {
336                    name: "log10".into(),
337                    args: 1,
338                    variadic: false,
339                    callback: |args| {
340                        if args[0] <= Decimal::ZERO {
341                            Err(FuncError::DomainError {
342                                function: "log10".to_string(),
343                                input: args[0],
344                            })
345                        } else {
346                            Ok(args[0].log10())
347                        }
348                    },
349                    description: Some("Base-10 logarithm".into()),
350                },
351                Symbol::Func {
352                    name: "exp".into(),
353                    args: 1,
354                    variadic: false,
355                    callback: |args| {
356                        let input = args[0];
357                        match panic::catch_unwind(panic::AssertUnwindSafe(|| input.exp())) {
358                            Ok(result) => Ok(result),
359                            Err(_) => Err(FuncError::MathError {
360                                message: "Exponential overflow or underflow".to_string(),
361                            }),
362                        }
363                    },
364                    description: Some("e raised to power x".into()),
365                },
366                Symbol::Func {
367                    name: "exp2".into(),
368                    args: 1,
369                    variadic: false,
370                    callback: |args| f64_calc_1(args, |x| x.exp2()),
371                    description: Some("2 raised to power x".into()),
372                },
373                // Basic math functions
374                Symbol::Func {
375                    name: "abs".into(),
376                    args: 1,
377                    variadic: false,
378                    callback: |args| Ok(args[0].abs()),
379                    description: Some("Absolute value".into()),
380                },
381                Symbol::Func {
382                    name: "sign".into(),
383                    args: 1,
384                    variadic: false,
385                    callback: |args| Ok(args[0].signum()),
386                    description: Some("Sign function (-1, 0, or 1)".into()),
387                },
388                Symbol::Func {
389                    name: "floor".into(),
390                    args: 1,
391                    variadic: false,
392                    callback: |args| Ok(args[0].floor()),
393                    description: Some("Floor function".into()),
394                },
395                Symbol::Func {
396                    name: "ceil".into(),
397                    args: 1,
398                    variadic: false,
399                    callback: |args| Ok(args[0].ceil()),
400                    description: Some("Ceiling function".into()),
401                },
402                Symbol::Func {
403                    name: "round".into(),
404                    args: 1,
405                    variadic: false,
406                    callback: |args| Ok(args[0].round()),
407                    description: Some("Round to nearest integer".into()),
408                },
409                Symbol::Func {
410                    name: "trunc".into(),
411                    args: 1,
412                    variadic: false,
413                    callback: |args| Ok(args[0].trunc()),
414                    description: Some("Truncate to integer".into()),
415                },
416                Symbol::Func {
417                    name: "fract".into(),
418                    args: 1,
419                    variadic: false,
420                    callback: |args| Ok(args[0].fract()),
421                    description: Some("Fractional part".into()),
422                },
423                Symbol::Func {
424                    name: "mod".into(),
425                    args: 2,
426                    variadic: false,
427                    callback: |args| Ok(args[0] % args[1]),
428                    description: Some("Remainder of x/y".into()),
429                },
430                Symbol::Func {
431                    name: "hypot".into(),
432                    args: 2,
433                    variadic: false,
434                    callback: |args| f64_calc_2(args, |x, y| x.hypot(y)),
435                    description: Some("Euclidean distance sqrt(x²+y²)".into()),
436                },
437                Symbol::Func {
438                    name: "clamp".into(),
439                    args: 3,
440                    variadic: false,
441                    callback: |args| Ok(args[0].clamp(args[1].min(args[2]), args[2].max(args[1]))),
442                    description: Some("Constrain value between bounds".into()),
443                },
444                Symbol::Func {
445                    name: "if".into(),
446                    args: 3,
447                    variadic: false,
448                    callback: |args| {
449                        if args[0] != Decimal::ZERO {
450                            Ok(args[1])
451                        } else {
452                            Ok(args[2])
453                        }
454                    },
455                    description: Some(
456                        "Conditional expression: if(condition, true_value, false_value)".into(),
457                    ),
458                },
459                // Variadic functions
460                Symbol::Func {
461                    name: "min".into(),
462                    args: 1,
463                    variadic: true,
464                    callback: |args| {
465                        Ok(*args.iter().min().ok_or_else(|| FuncError::MathError {
466                            message: "min() requires at least one argument".to_string(),
467                        })?)
468                    },
469                    description: Some("Minimum value".into()),
470                },
471                Symbol::Func {
472                    name: "max".into(),
473                    args: 1,
474                    variadic: true,
475                    callback: |args| {
476                        Ok(*args.iter().max().ok_or_else(|| FuncError::MathError {
477                            message: "max() requires at least one argument".to_string(),
478                        })?)
479                    },
480                    description: Some("Maximum value".into()),
481                },
482                Symbol::Func {
483                    name: "sum".into(),
484                    args: 1,
485                    variadic: true,
486                    callback: |args| Ok(args.iter().sum()),
487                    description: Some("Sum of values".into()),
488                },
489                Symbol::Func {
490                    name: "avg".into(),
491                    args: 1,
492                    variadic: true,
493                    callback: |args| {
494                        let sum: Decimal = args.iter().sum();
495                        let count = Decimal::from(args.len());
496                        Ok(sum / count)
497                    },
498                    description: Some("Average of values".into()),
499                },
500            ],
501        }
502    }
503
504    /// Adds a constant to the table.
505    ///
506    /// Returns an error if a symbol with the same name already exists.
507    pub fn add_const<S: Into<Cow<'static, str>>>(
508        &mut self,
509        name: S,
510        value: Decimal,
511    ) -> Result<&mut Self, SymbolError> {
512        let name = name.into();
513        if self.get(&name).is_some() {
514            return Err(SymbolError::DuplicateSymbol(name.to_string()));
515        }
516        self.symbols.push(Symbol::Const {
517            name,
518            value,
519            description: None,
520        });
521        Ok(self)
522    }
523
524    /// Adds a function to the table.
525    ///
526    /// # Parameters
527    /// - `name`: Function name
528    /// - `args`: Minimum number of arguments
529    /// - `variadic`: Whether the function accepts additional arguments
530    /// - `callback`: Function implementation
531    ///
532    /// Returns an error if a symbol with the same name already exists.
533    pub fn add_func<S: Into<Cow<'static, str>>>(
534        &mut self,
535        name: S,
536        args: usize,
537        variadic: bool,
538        callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
539    ) -> Result<&mut Self, SymbolError> {
540        let name = name.into();
541        if self.get(&name).is_some() {
542            return Err(SymbolError::DuplicateSymbol(name.to_string()));
543        }
544        self.symbols.push(Symbol::Func {
545            name,
546            args,
547            variadic,
548            callback,
549            description: None,
550        });
551        Ok(self)
552    }
553
554    /// Looks up a symbol by name (case-insensitive).
555    pub fn get(&self, name: &str) -> Option<&Symbol> {
556        self.symbols
557            .iter()
558            .find(|sym| sym.name().eq_ignore_ascii_case(name))
559    }
560
561    /// Returns an iterator over all symbols in the table.
562    pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
563        self.symbols.iter()
564    }
565}