expr_solver/
symbol.rs

1use rust_decimal::Decimal;
2use rust_decimal::prelude::*;
3use std::borrow::Cow;
4use std::panic;
5use thiserror::Error;
6
7/// Errors that can occur during function evaluation.
8#[derive(Error, Debug, Clone)]
9pub enum FuncError {
10    #[error("Conversion error: failed to convert Decimal to f64")]
11    DecimalToF64Conversion,
12    #[error("Conversion error: failed to convert f64 result back to Decimal")]
13    F64ToDecimalConversion,
14    #[error("Square root of negative number: {value}")]
15    NegativeSqrt { value: Decimal },
16    #[error("Domain error in function '{function}': invalid input {input}")]
17    DomainError { function: String, input: Decimal },
18    #[error("Math error: {message}")]
19    MathError { message: String },
20}
21
22/// Helper function for single-argument f64 calculations
23fn f64_calc_1<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
24where
25    F: Fn(f64) -> f64,
26{
27    let arg = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
28    let result = func(arg);
29    Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
30}
31
32/// Helper function for two-argument f64 calculations
33fn f64_calc_2<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
34where
35    F: Fn(f64, f64) -> f64,
36{
37    let arg1 = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
38    let arg2 = args[1].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
39    let result = func(arg1, arg2);
40    Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
41}
42
43/// Errors that can occur during symbol table operations.
44#[derive(Error, Debug, Clone, PartialEq, Eq)]
45pub enum SymbolError {
46    /// A symbol with this name already exists in the table.
47    #[error("Duplicate symbol definition: '{0}'")]
48    DuplicateSymbol(String),
49}
50
51/// What a symbol represents in the language.
52#[derive(Debug, Clone)]
53pub enum Symbol {
54    /// Named constant (e.g., `pi`). Returns a Decimal for now.
55    Const {
56        name: Cow<'static, str>,
57        value: Decimal,
58        description: Option<Cow<'static, str>>,
59    },
60    /// Function with specified arity and callback.
61    Func {
62        name: Cow<'static, str>,
63        args: usize,
64        variadic: bool,
65        callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
66        description: Option<Cow<'static, str>>,
67    },
68}
69
70impl Symbol {
71    pub fn name(&self) -> &str {
72        match self {
73            Symbol::Const { name, .. } => name,
74            Symbol::Func { name, .. } => name,
75        }
76    }
77
78    pub fn description(&self) -> Option<&str> {
79        match self {
80            Symbol::Const { description, .. } => description.as_deref(),
81            Symbol::Func { description, .. } => description.as_deref(),
82        }
83    }
84}
85
86/// Symbol table containing constants and functions.
87#[derive(Debug, Default, Clone)]
88pub struct SymTable {
89    symbols: Vec<Symbol>,
90}
91
92impl SymTable {
93    /// Create an empty symbol table.
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Create a symbol table with built-in constants and functions.
99    ///
100    /// ## Constants
101    /// - `pi` - π (3.14159...)
102    /// - `e` - Euler's number (2.71828...)
103    /// - `tau` - 2π (6.28318...)
104    /// - `ln2` - Natural logarithm of 2
105    /// - `ln10` - Natural logarithm of 10
106    /// - `sqrt2` - Square root of 2
107    ///
108    /// ## Fixed arity functions
109    /// - `sin(x)` - Sine
110    /// - `cos(x)` - Cosine
111    /// - `tan(x)` - Tangent
112    /// - `asin(x)` - Arcsine
113    /// - `acos(x)` - Arccosine
114    /// - `atan(x)` - Arctangent
115    /// - `atan2(y, x)` - Two-argument arctangent
116    /// - `sinh(x)` - Hyperbolic sine
117    /// - `cosh(x)` - Hyperbolic cosine
118    /// - `tanh(x)` - Hyperbolic tangent
119    /// - `sqrt(x)` - Square root
120    /// - `cbrt(x)` - Cube root
121    /// - `pow(x, y)` - x raised to power y
122    /// - `log(x)` - Natural logarithm
123    /// - `log2(x)` - Base-2 logarithm
124    /// - `log10(x)` - Base-10 logarithm
125    /// - `exp(x)` - e raised to power x
126    /// - `exp2(x)` - 2 raised to power x
127    /// - `abs(x)` - Absolute value
128    /// - `sign(x)` - Sign function (-1, 0, or 1)
129    /// - `floor(x)` - Floor function
130    /// - `ceil(x)` - Ceiling function
131    /// - `round(x)` - Round to nearest integer
132    /// - `trunc(x)` - Truncate to integer
133    /// - `fract(x)` - Fractional part
134    /// - `mod(x, y)` - Remainder of x/y
135    /// - `hypot(x, y)` - Euclidean distance sqrt(x²+y²)
136    /// - `clamp(x, min, max)` - Constrain value between bounds
137    ///
138    /// ## Variadic functions
139    /// - `min(x, ...)` - Minimum value
140    /// - `max(x, ...)` - Maximum value
141    /// - `sum(x, ...)` - Sum of values
142    /// - `avg(x, ...)` - Average of values
143    pub fn stdlib() -> Self {
144        Self {
145            symbols: vec![
146                // Constants
147                Symbol::Const {
148                    name: "pi".into(),
149                    value: Decimal::PI,
150                    description: Some("π (3.14159...)".into()),
151                },
152                Symbol::Const {
153                    name: "e".into(),
154                    value: Decimal::E,
155                    description: Some("Euler's number (2.71828...)".into()),
156                },
157                Symbol::Const {
158                    name: "tau".into(),
159                    value: Decimal::TWO_PI,
160                    description: Some("2π (6.28318...)".into()),
161                },
162                Symbol::Const {
163                    name: "ln2".into(),
164                    value: Decimal::from_f64(std::f64::consts::LN_2).unwrap(),
165                    description: Some("Natural logarithm of 2".into()),
166                },
167                Symbol::Const {
168                    name: "ln10".into(),
169                    value: Decimal::from_f64(std::f64::consts::LN_10).unwrap(),
170                    description: Some("Natural logarithm of 10".into()),
171                },
172                Symbol::Const {
173                    name: "sqrt2".into(),
174                    value: Decimal::from_f64(std::f64::consts::SQRT_2).unwrap(),
175                    description: Some("Square root of 2".into()),
176                },
177                // Trigonometric functions
178                Symbol::Func {
179                    name: "sin".into(),
180                    args: 1,
181                    variadic: false,
182                    callback: |args| Ok(args[0].sin()),
183                    description: Some("Sine".into()),
184                },
185                Symbol::Func {
186                    name: "cos".into(),
187                    args: 1,
188                    variadic: false,
189                    callback: |args| Ok(args[0].cos()),
190                    description: Some("Cosine".into()),
191                },
192                Symbol::Func {
193                    name: "tan".into(),
194                    args: 1,
195                    variadic: false,
196                    callback: |args| {
197                        let input = args[0];
198                        match panic::catch_unwind(panic::AssertUnwindSafe(|| input.tan())) {
199                            Ok(result) => Ok(result),
200                            Err(_) => Err(FuncError::DomainError {
201                                function: "tan".to_string(),
202                                input,
203                            }),
204                        }
205                    },
206                    description: Some("Tangent".into()),
207                },
208                Symbol::Func {
209                    name: "asin".into(),
210                    args: 1,
211                    variadic: false,
212                    callback: |args| f64_calc_1(args, |x| x.asin()),
213                    description: Some("Arcsine".into()),
214                },
215                Symbol::Func {
216                    name: "acos".into(),
217                    args: 1,
218                    variadic: false,
219                    callback: |args| f64_calc_1(args, |x| x.acos()),
220                    description: Some("Arccosine".into()),
221                },
222                Symbol::Func {
223                    name: "atan".into(),
224                    args: 1,
225                    variadic: false,
226                    callback: |args| f64_calc_1(args, |x| x.atan()),
227                    description: Some("Arctangent".into()),
228                },
229                Symbol::Func {
230                    name: "atan2".into(),
231                    args: 2,
232                    variadic: false,
233                    callback: |args| f64_calc_2(args, |y, x| y.atan2(x)),
234                    description: Some("Two-argument arctangent".into()),
235                },
236                Symbol::Func {
237                    name: "sinh".into(),
238                    args: 1,
239                    variadic: false,
240                    callback: |args| f64_calc_1(args, |x| x.sinh()),
241                    description: Some("Hyperbolic sine".into()),
242                },
243                Symbol::Func {
244                    name: "cosh".into(),
245                    args: 1,
246                    variadic: false,
247                    callback: |args| f64_calc_1(args, |x| x.cosh()),
248                    description: Some("Hyperbolic cosine".into()),
249                },
250                Symbol::Func {
251                    name: "tanh".into(),
252                    args: 1,
253                    variadic: false,
254                    callback: |args| f64_calc_1(args, |x| x.tanh()),
255                    description: Some("Hyperbolic tangent".into()),
256                },
257                // Power and root functions
258                Symbol::Func {
259                    name: "sqrt".into(),
260                    args: 1,
261                    variadic: false,
262                    callback: |args| {
263                        args[0]
264                            .sqrt()
265                            .ok_or_else(|| FuncError::NegativeSqrt { value: args[0] })
266                    },
267                    description: Some("Square root".into()),
268                },
269                Symbol::Func {
270                    name: "cbrt".into(),
271                    args: 1,
272                    variadic: false,
273                    callback: |args| f64_calc_1(args, |x| x.cbrt()),
274                    description: Some("Cube root".into()),
275                },
276                Symbol::Func {
277                    name: "pow".into(),
278                    args: 2,
279                    variadic: false,
280                    callback: |args| {
281                        let base = args[0];
282                        let exponent = args[1];
283                        match panic::catch_unwind(panic::AssertUnwindSafe(|| base.powd(exponent))) {
284                            Ok(result) => Ok(result),
285                            Err(_) => Err(FuncError::MathError {
286                                message: format!("Power operation failed: {}^{}", base, exponent),
287                            }),
288                        }
289                    },
290                    description: Some("x raised to power y".into()),
291                },
292                // Logarithmic and exponential functions
293                Symbol::Func {
294                    name: "log".into(),
295                    args: 1,
296                    variadic: false,
297                    callback: |args| {
298                        if args[0] <= Decimal::ZERO {
299                            Err(FuncError::DomainError {
300                                function: "log".to_string(),
301                                input: args[0],
302                            })
303                        } else {
304                            Ok(args[0].ln())
305                        }
306                    },
307                    description: Some("Natural logarithm".into()),
308                },
309                Symbol::Func {
310                    name: "log2".into(),
311                    args: 1,
312                    variadic: false,
313                    callback: |args| f64_calc_1(args, |x| x.log2()),
314                    description: Some("Base-2 logarithm".into()),
315                },
316                Symbol::Func {
317                    name: "log10".into(),
318                    args: 1,
319                    variadic: false,
320                    callback: |args| {
321                        if args[0] <= Decimal::ZERO {
322                            Err(FuncError::DomainError {
323                                function: "log10".to_string(),
324                                input: args[0],
325                            })
326                        } else {
327                            Ok(args[0].log10())
328                        }
329                    },
330                    description: Some("Base-10 logarithm".into()),
331                },
332                Symbol::Func {
333                    name: "exp".into(),
334                    args: 1,
335                    variadic: false,
336                    callback: |args| {
337                        let input = args[0];
338                        match panic::catch_unwind(panic::AssertUnwindSafe(|| input.exp())) {
339                            Ok(result) => Ok(result),
340                            Err(_) => Err(FuncError::MathError {
341                                message: "Exponential overflow or underflow".to_string(),
342                            }),
343                        }
344                    },
345                    description: Some("e raised to power x".into()),
346                },
347                Symbol::Func {
348                    name: "exp2".into(),
349                    args: 1,
350                    variadic: false,
351                    callback: |args| f64_calc_1(args, |x| x.exp2()),
352                    description: Some("2 raised to power x".into()),
353                },
354                // Basic math functions
355                Symbol::Func {
356                    name: "abs".into(),
357                    args: 1,
358                    variadic: false,
359                    callback: |args| Ok(args[0].abs()),
360                    description: Some("Absolute value".into()),
361                },
362                Symbol::Func {
363                    name: "sign".into(),
364                    args: 1,
365                    variadic: false,
366                    callback: |args| Ok(args[0].signum()),
367                    description: Some("Sign function (-1, 0, or 1)".into()),
368                },
369                Symbol::Func {
370                    name: "floor".into(),
371                    args: 1,
372                    variadic: false,
373                    callback: |args| Ok(args[0].floor()),
374                    description: Some("Floor function".into()),
375                },
376                Symbol::Func {
377                    name: "ceil".into(),
378                    args: 1,
379                    variadic: false,
380                    callback: |args| Ok(args[0].ceil()),
381                    description: Some("Ceiling function".into()),
382                },
383                Symbol::Func {
384                    name: "round".into(),
385                    args: 1,
386                    variadic: false,
387                    callback: |args| Ok(args[0].round()),
388                    description: Some("Round to nearest integer".into()),
389                },
390                Symbol::Func {
391                    name: "trunc".into(),
392                    args: 1,
393                    variadic: false,
394                    callback: |args| Ok(args[0].trunc()),
395                    description: Some("Truncate to integer".into()),
396                },
397                Symbol::Func {
398                    name: "fract".into(),
399                    args: 1,
400                    variadic: false,
401                    callback: |args| Ok(args[0].fract()),
402                    description: Some("Fractional part".into()),
403                },
404                Symbol::Func {
405                    name: "mod".into(),
406                    args: 2,
407                    variadic: false,
408                    callback: |args| Ok(args[0] % args[1]),
409                    description: Some("Remainder of x/y".into()),
410                },
411                Symbol::Func {
412                    name: "hypot".into(),
413                    args: 2,
414                    variadic: false,
415                    callback: |args| f64_calc_2(args, |x, y| x.hypot(y)),
416                    description: Some("Euclidean distance sqrt(x²+y²)".into()),
417                },
418                Symbol::Func {
419                    name: "clamp".into(),
420                    args: 3,
421                    variadic: false,
422                    callback: |args| Ok(args[0].clamp(args[1].min(args[2]), args[2].max(args[1]))),
423                    description: Some("Constrain value between bounds".into()),
424                },
425                Symbol::Func {
426                    name: "if".into(),
427                    args: 3,
428                    variadic: false,
429                    callback: |args| {
430                        if args[0] != Decimal::ZERO {
431                            Ok(args[1])
432                        } else {
433                            Ok(args[2])
434                        }
435                    },
436                    description: Some("Conditional expression: if(condition, true_value, false_value)".into()),
437                },
438                // Variadic functions
439                Symbol::Func {
440                    name: "min".into(),
441                    args: 1,
442                    variadic: true,
443                    callback: |args| {
444                        Ok(*args.iter().min().ok_or_else(|| FuncError::MathError {
445                            message: "min() requires at least one argument".to_string(),
446                        })?)
447                    },
448                    description: Some("Minimum value".into()),
449                },
450                Symbol::Func {
451                    name: "max".into(),
452                    args: 1,
453                    variadic: true,
454                    callback: |args| {
455                        Ok(*args.iter().max().ok_or_else(|| FuncError::MathError {
456                            message: "max() requires at least one argument".to_string(),
457                        })?)
458                    },
459                    description: Some("Maximum value".into()),
460                },
461                Symbol::Func {
462                    name: "sum".into(),
463                    args: 1,
464                    variadic: true,
465                    callback: |args| Ok(args.iter().sum()),
466                    description: Some("Sum of values".into()),
467                },
468                Symbol::Func {
469                    name: "avg".into(),
470                    args: 1,
471                    variadic: true,
472                    callback: |args| {
473                        let sum: Decimal = args.iter().sum();
474                        let count = Decimal::from(args.len());
475                        Ok(sum / count)
476                    },
477                    description: Some("Average of values".into()),
478                },
479            ],
480        }
481    }
482
483    /// Add a constant to the table.
484    pub fn add_const<S: Into<Cow<'static, str>>>(
485        &mut self,
486        name: S,
487        value: Decimal,
488    ) -> Result<&mut Self, SymbolError> {
489        let name = name.into();
490        if self.get(&name).is_some() {
491            return Err(SymbolError::DuplicateSymbol(name.to_string()));
492        }
493        self.symbols.push(Symbol::Const {
494            name,
495            value,
496            description: None,
497        });
498        Ok(self)
499    }
500
501    /// Add a function to the table.
502    ///
503    /// Returns an error if a symbol with the same name already exists.
504    pub fn add_func<S: Into<Cow<'static, str>>>(
505        &mut self,
506        name: S,
507        args: usize,
508        variadic: bool,
509        callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
510    ) -> Result<&mut Self, SymbolError> {
511        let name = name.into();
512        if self.get(&name).is_some() {
513            return Err(SymbolError::DuplicateSymbol(name.to_string()));
514        }
515        self.symbols.push(Symbol::Func {
516            name,
517            args,
518            variadic,
519            callback,
520            description: None,
521        });
522        Ok(self)
523    }
524
525    /// Look up a symbol by name (case-insensitive).
526    pub fn get(&self, name: &str) -> Option<&Symbol> {
527        self.symbols
528            .iter()
529            .find(|sym| sym.name().eq_ignore_ascii_case(name))
530    }
531
532    /// Get an iterator over all symbols in the table.
533    pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
534        self.symbols.iter()
535    }
536}