1use rust_decimal::Decimal;
2use rust_decimal::prelude::*;
3use std::borrow::Cow;
4use std::panic;
5use thiserror::Error;
6
7#[derive(Error, Debug, Clone)]
9pub enum FuncError {
10 #[error("Conversion error: failed to convert Decimal to f64")]
11 DecimalToF64Conversion,
12 #[error("Conversion error: failed to convert f64 result back to Decimal")]
13 F64ToDecimalConversion,
14 #[error("Square root of negative number: {value}")]
15 NegativeSqrt { value: Decimal },
16 #[error("Domain error in function '{function}': invalid input {input}")]
17 DomainError { function: String, input: Decimal },
18 #[error("Math error: {message}")]
19 MathError { message: String },
20}
21
22fn f64_calc_1<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
24where
25 F: Fn(f64) -> f64,
26{
27 let arg = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
28 let result = func(arg);
29 Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
30}
31
32fn f64_calc_2<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
34where
35 F: Fn(f64, f64) -> f64,
36{
37 let arg1 = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
38 let arg2 = args[1].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
39 let result = func(arg1, arg2);
40 Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
41}
42
/// Errors returned when registering a symbol in a [`SymTable`].
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum SymbolError {
    /// A symbol with the same name (compared ASCII case-insensitively by `SymTable::get`)
    /// is already registered; the payload is the name that was rejected.
    #[error("Duplicate symbol definition: '{0}'")]
    DuplicateSymbol(String),
}
50
/// An entry in a [`SymTable`]: either a named constant or a callable function.
#[derive(Debug, Clone)]
pub enum Symbol {
    /// A named constant value.
    Const {
        /// Symbol name; `SymTable::get` matches it ASCII case-insensitively.
        name: Cow<'static, str>,
        /// The constant's value.
        value: Decimal,
        /// Optional human-readable description.
        description: Option<Cow<'static, str>>,
    },
    /// A callable function implemented by a plain function pointer.
    Func {
        /// Symbol name; `SymTable::get` matches it ASCII case-insensitively.
        name: Cow<'static, str>,
        /// Declared argument count. The stdlib's variadic entries set this to 1, which
        /// presumably means "at least this many" — NOTE(review): confirm with the evaluator.
        args: usize,
        /// Whether the function accepts a variable number of arguments.
        variadic: bool,
        /// Implementation. Stdlib callbacks index the slice directly, so the caller is
        /// expected to have validated the arity first (TODO confirm where that happens).
        callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
        /// Optional human-readable description.
        description: Option<Cow<'static, str>>,
    },
}
69
70impl Symbol {
71 pub fn name(&self) -> &str {
72 match self {
73 Symbol::Const { name, .. } => name,
74 Symbol::Func { name, .. } => name,
75 }
76 }
77
78 pub fn description(&self) -> Option<&str> {
79 match self {
80 Symbol::Const { description, .. } => description.as_deref(),
81 Symbol::Func { description, .. } => description.as_deref(),
82 }
83 }
84}
85
/// An ordered collection of constants and functions available to expressions.
///
/// Lookups (`get`) are linear scans using ASCII case-insensitive name matching.
#[derive(Debug, Default, Clone)]
pub struct SymTable {
    /// Registered symbols, in the order they were added.
    symbols: Vec<Symbol>,
}
91
92impl SymTable {
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn stdlib() -> Self {
144 Self {
145 symbols: vec![
146 Symbol::Const {
148 name: "pi".into(),
149 value: Decimal::PI,
150 description: Some("π (3.14159...)".into()),
151 },
152 Symbol::Const {
153 name: "e".into(),
154 value: Decimal::E,
155 description: Some("Euler's number (2.71828...)".into()),
156 },
157 Symbol::Const {
158 name: "tau".into(),
159 value: Decimal::TWO_PI,
160 description: Some("2π (6.28318...)".into()),
161 },
162 Symbol::Const {
163 name: "ln2".into(),
164 value: Decimal::from_f64(std::f64::consts::LN_2).unwrap(),
165 description: Some("Natural logarithm of 2".into()),
166 },
167 Symbol::Const {
168 name: "ln10".into(),
169 value: Decimal::from_f64(std::f64::consts::LN_10).unwrap(),
170 description: Some("Natural logarithm of 10".into()),
171 },
172 Symbol::Const {
173 name: "sqrt2".into(),
174 value: Decimal::from_f64(std::f64::consts::SQRT_2).unwrap(),
175 description: Some("Square root of 2".into()),
176 },
177 Symbol::Func {
179 name: "sin".into(),
180 args: 1,
181 variadic: false,
182 callback: |args| Ok(args[0].sin()),
183 description: Some("Sine".into()),
184 },
185 Symbol::Func {
186 name: "cos".into(),
187 args: 1,
188 variadic: false,
189 callback: |args| Ok(args[0].cos()),
190 description: Some("Cosine".into()),
191 },
192 Symbol::Func {
193 name: "tan".into(),
194 args: 1,
195 variadic: false,
196 callback: |args| {
197 let input = args[0];
198 match panic::catch_unwind(panic::AssertUnwindSafe(|| input.tan())) {
199 Ok(result) => Ok(result),
200 Err(_) => Err(FuncError::DomainError {
201 function: "tan".to_string(),
202 input,
203 }),
204 }
205 },
206 description: Some("Tangent".into()),
207 },
208 Symbol::Func {
209 name: "asin".into(),
210 args: 1,
211 variadic: false,
212 callback: |args| f64_calc_1(args, |x| x.asin()),
213 description: Some("Arcsine".into()),
214 },
215 Symbol::Func {
216 name: "acos".into(),
217 args: 1,
218 variadic: false,
219 callback: |args| f64_calc_1(args, |x| x.acos()),
220 description: Some("Arccosine".into()),
221 },
222 Symbol::Func {
223 name: "atan".into(),
224 args: 1,
225 variadic: false,
226 callback: |args| f64_calc_1(args, |x| x.atan()),
227 description: Some("Arctangent".into()),
228 },
229 Symbol::Func {
230 name: "atan2".into(),
231 args: 2,
232 variadic: false,
233 callback: |args| f64_calc_2(args, |y, x| y.atan2(x)),
234 description: Some("Two-argument arctangent".into()),
235 },
236 Symbol::Func {
237 name: "sinh".into(),
238 args: 1,
239 variadic: false,
240 callback: |args| f64_calc_1(args, |x| x.sinh()),
241 description: Some("Hyperbolic sine".into()),
242 },
243 Symbol::Func {
244 name: "cosh".into(),
245 args: 1,
246 variadic: false,
247 callback: |args| f64_calc_1(args, |x| x.cosh()),
248 description: Some("Hyperbolic cosine".into()),
249 },
250 Symbol::Func {
251 name: "tanh".into(),
252 args: 1,
253 variadic: false,
254 callback: |args| f64_calc_1(args, |x| x.tanh()),
255 description: Some("Hyperbolic tangent".into()),
256 },
257 Symbol::Func {
259 name: "sqrt".into(),
260 args: 1,
261 variadic: false,
262 callback: |args| {
263 args[0]
264 .sqrt()
265 .ok_or_else(|| FuncError::NegativeSqrt { value: args[0] })
266 },
267 description: Some("Square root".into()),
268 },
269 Symbol::Func {
270 name: "cbrt".into(),
271 args: 1,
272 variadic: false,
273 callback: |args| f64_calc_1(args, |x| x.cbrt()),
274 description: Some("Cube root".into()),
275 },
276 Symbol::Func {
277 name: "pow".into(),
278 args: 2,
279 variadic: false,
280 callback: |args| {
281 let base = args[0];
282 let exponent = args[1];
283 match panic::catch_unwind(panic::AssertUnwindSafe(|| base.powd(exponent))) {
284 Ok(result) => Ok(result),
285 Err(_) => Err(FuncError::MathError {
286 message: format!("Power operation failed: {}^{}", base, exponent),
287 }),
288 }
289 },
290 description: Some("x raised to power y".into()),
291 },
292 Symbol::Func {
294 name: "log".into(),
295 args: 1,
296 variadic: false,
297 callback: |args| {
298 if args[0] <= Decimal::ZERO {
299 Err(FuncError::DomainError {
300 function: "log".to_string(),
301 input: args[0],
302 })
303 } else {
304 Ok(args[0].ln())
305 }
306 },
307 description: Some("Natural logarithm".into()),
308 },
309 Symbol::Func {
310 name: "log2".into(),
311 args: 1,
312 variadic: false,
313 callback: |args| f64_calc_1(args, |x| x.log2()),
314 description: Some("Base-2 logarithm".into()),
315 },
316 Symbol::Func {
317 name: "log10".into(),
318 args: 1,
319 variadic: false,
320 callback: |args| {
321 if args[0] <= Decimal::ZERO {
322 Err(FuncError::DomainError {
323 function: "log10".to_string(),
324 input: args[0],
325 })
326 } else {
327 Ok(args[0].log10())
328 }
329 },
330 description: Some("Base-10 logarithm".into()),
331 },
332 Symbol::Func {
333 name: "exp".into(),
334 args: 1,
335 variadic: false,
336 callback: |args| {
337 let input = args[0];
338 match panic::catch_unwind(panic::AssertUnwindSafe(|| input.exp())) {
339 Ok(result) => Ok(result),
340 Err(_) => Err(FuncError::MathError {
341 message: "Exponential overflow or underflow".to_string(),
342 }),
343 }
344 },
345 description: Some("e raised to power x".into()),
346 },
347 Symbol::Func {
348 name: "exp2".into(),
349 args: 1,
350 variadic: false,
351 callback: |args| f64_calc_1(args, |x| x.exp2()),
352 description: Some("2 raised to power x".into()),
353 },
354 Symbol::Func {
356 name: "abs".into(),
357 args: 1,
358 variadic: false,
359 callback: |args| Ok(args[0].abs()),
360 description: Some("Absolute value".into()),
361 },
362 Symbol::Func {
363 name: "sign".into(),
364 args: 1,
365 variadic: false,
366 callback: |args| Ok(args[0].signum()),
367 description: Some("Sign function (-1, 0, or 1)".into()),
368 },
369 Symbol::Func {
370 name: "floor".into(),
371 args: 1,
372 variadic: false,
373 callback: |args| Ok(args[0].floor()),
374 description: Some("Floor function".into()),
375 },
376 Symbol::Func {
377 name: "ceil".into(),
378 args: 1,
379 variadic: false,
380 callback: |args| Ok(args[0].ceil()),
381 description: Some("Ceiling function".into()),
382 },
383 Symbol::Func {
384 name: "round".into(),
385 args: 1,
386 variadic: false,
387 callback: |args| Ok(args[0].round()),
388 description: Some("Round to nearest integer".into()),
389 },
390 Symbol::Func {
391 name: "trunc".into(),
392 args: 1,
393 variadic: false,
394 callback: |args| Ok(args[0].trunc()),
395 description: Some("Truncate to integer".into()),
396 },
397 Symbol::Func {
398 name: "fract".into(),
399 args: 1,
400 variadic: false,
401 callback: |args| Ok(args[0].fract()),
402 description: Some("Fractional part".into()),
403 },
404 Symbol::Func {
405 name: "mod".into(),
406 args: 2,
407 variadic: false,
408 callback: |args| Ok(args[0] % args[1]),
409 description: Some("Remainder of x/y".into()),
410 },
411 Symbol::Func {
412 name: "hypot".into(),
413 args: 2,
414 variadic: false,
415 callback: |args| f64_calc_2(args, |x, y| x.hypot(y)),
416 description: Some("Euclidean distance sqrt(x²+y²)".into()),
417 },
418 Symbol::Func {
419 name: "clamp".into(),
420 args: 3,
421 variadic: false,
422 callback: |args| Ok(args[0].clamp(args[1].min(args[2]), args[2].max(args[1]))),
423 description: Some("Constrain value between bounds".into()),
424 },
425 Symbol::Func {
426 name: "if".into(),
427 args: 3,
428 variadic: false,
429 callback: |args| {
430 if args[0] != Decimal::ZERO {
431 Ok(args[1])
432 } else {
433 Ok(args[2])
434 }
435 },
436 description: Some("Conditional expression: if(condition, true_value, false_value)".into()),
437 },
438 Symbol::Func {
440 name: "min".into(),
441 args: 1,
442 variadic: true,
443 callback: |args| {
444 Ok(*args.iter().min().ok_or_else(|| FuncError::MathError {
445 message: "min() requires at least one argument".to_string(),
446 })?)
447 },
448 description: Some("Minimum value".into()),
449 },
450 Symbol::Func {
451 name: "max".into(),
452 args: 1,
453 variadic: true,
454 callback: |args| {
455 Ok(*args.iter().max().ok_or_else(|| FuncError::MathError {
456 message: "max() requires at least one argument".to_string(),
457 })?)
458 },
459 description: Some("Maximum value".into()),
460 },
461 Symbol::Func {
462 name: "sum".into(),
463 args: 1,
464 variadic: true,
465 callback: |args| Ok(args.iter().sum()),
466 description: Some("Sum of values".into()),
467 },
468 Symbol::Func {
469 name: "avg".into(),
470 args: 1,
471 variadic: true,
472 callback: |args| {
473 let sum: Decimal = args.iter().sum();
474 let count = Decimal::from(args.len());
475 Ok(sum / count)
476 },
477 description: Some("Average of values".into()),
478 },
479 ],
480 }
481 }
482
483 pub fn add_const<S: Into<Cow<'static, str>>>(
485 &mut self,
486 name: S,
487 value: Decimal,
488 ) -> Result<&mut Self, SymbolError> {
489 let name = name.into();
490 if self.get(&name).is_some() {
491 return Err(SymbolError::DuplicateSymbol(name.to_string()));
492 }
493 self.symbols.push(Symbol::Const {
494 name,
495 value,
496 description: None,
497 });
498 Ok(self)
499 }
500
501 pub fn add_func<S: Into<Cow<'static, str>>>(
505 &mut self,
506 name: S,
507 args: usize,
508 variadic: bool,
509 callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
510 ) -> Result<&mut Self, SymbolError> {
511 let name = name.into();
512 if self.get(&name).is_some() {
513 return Err(SymbolError::DuplicateSymbol(name.to_string()));
514 }
515 self.symbols.push(Symbol::Func {
516 name,
517 args,
518 variadic,
519 callback,
520 description: None,
521 });
522 Ok(self)
523 }
524
525 pub fn get(&self, name: &str) -> Option<&Symbol> {
527 self.symbols
528 .iter()
529 .find(|sym| sym.name().eq_ignore_ascii_case(name))
530 }
531
532 pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
534 self.symbols.iter()
535 }
536}