1use rust_decimal::Decimal;
2use rust_decimal::prelude::*;
3use std::borrow::Cow;
4use std::panic;
5use thiserror::Error;
6
7#[derive(Error, Debug, Clone)]
9pub enum FuncError {
10 #[error("Conversion error: failed to convert Decimal to f64")]
11 DecimalToF64Conversion,
12 #[error("Conversion error: failed to convert f64 result back to Decimal")]
13 F64ToDecimalConversion,
14 #[error("Square root of negative number: {value}")]
15 NegativeSqrt { value: Decimal },
16 #[error("Domain error in function '{function}': invalid input {input}")]
17 DomainError { function: String, input: Decimal },
18 #[error("Math error: {message}")]
19 MathError { message: String },
20}
21
22fn f64_calc_1<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
24where
25 F: Fn(f64) -> f64,
26{
27 let arg = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
28 let result = func(arg);
29 Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
30}
31
32fn f64_calc_2<F>(args: &[Decimal], func: F) -> Result<Decimal, FuncError>
34where
35 F: Fn(f64, f64) -> f64,
36{
37 let arg1 = args[0].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
38 let arg2 = args[1].to_f64().ok_or(FuncError::DecimalToF64Conversion)?;
39 let result = func(arg1, arg2);
40 Decimal::from_f64(result).ok_or(FuncError::F64ToDecimalConversion)
41}
42
43#[derive(Error, Debug, Clone, PartialEq, Eq)]
45pub enum SymbolError {
46 #[error("Duplicate symbol definition: '{0}'")]
48 DuplicateSymbol(String),
49}
50
51#[derive(Debug, Clone)]
55pub enum Symbol {
56 Const {
58 name: Cow<'static, str>,
59 value: Decimal,
60 description: Option<Cow<'static, str>>,
61 },
62 Func {
64 name: Cow<'static, str>,
65 args: usize,
67 variadic: bool,
69 callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
70 description: Option<Cow<'static, str>>,
71 },
72}
73
74impl Symbol {
75 pub fn name(&self) -> &str {
77 match self {
78 Symbol::Const { name, .. } => name,
79 Symbol::Func { name, .. } => name,
80 }
81 }
82
83 pub fn description(&self) -> Option<&str> {
85 match self {
86 Symbol::Const { description, .. } => description.as_deref(),
87 Symbol::Func { description, .. } => description.as_deref(),
88 }
89 }
90}
91
92#[derive(Debug, Default, Clone)]
107pub struct SymTable {
108 symbols: Vec<Symbol>,
109}
110
111impl SymTable {
112 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn stdlib() -> Self {
163 Self {
164 symbols: vec![
165 Symbol::Const {
167 name: "pi".into(),
168 value: Decimal::PI,
169 description: Some("π (3.14159...)".into()),
170 },
171 Symbol::Const {
172 name: "e".into(),
173 value: Decimal::E,
174 description: Some("Euler's number (2.71828...)".into()),
175 },
176 Symbol::Const {
177 name: "tau".into(),
178 value: Decimal::TWO_PI,
179 description: Some("2π (6.28318...)".into()),
180 },
181 Symbol::Const {
182 name: "ln2".into(),
183 value: Decimal::TWO.ln(),
184 description: Some("Natural logarithm of 2".into()),
185 },
186 Symbol::Const {
187 name: "ln10".into(),
188 value: Decimal::TEN.log10(),
189 description: Some("Natural logarithm of 10".into()),
190 },
191 Symbol::Const {
192 name: "sqrt2".into(),
193 value: Decimal::TWO.sqrt().unwrap(),
194 description: Some("Square root of 2".into()),
195 },
196 Symbol::Func {
198 name: "sin".into(),
199 args: 1,
200 variadic: false,
201 callback: |args| Ok(args[0].sin()),
202 description: Some("Sine".into()),
203 },
204 Symbol::Func {
205 name: "cos".into(),
206 args: 1,
207 variadic: false,
208 callback: |args| Ok(args[0].cos()),
209 description: Some("Cosine".into()),
210 },
211 Symbol::Func {
212 name: "tan".into(),
213 args: 1,
214 variadic: false,
215 callback: |args| {
216 let input = args[0];
217 match panic::catch_unwind(panic::AssertUnwindSafe(|| input.tan())) {
218 Ok(result) => Ok(result),
219 Err(_) => Err(FuncError::DomainError {
220 function: "tan".to_string(),
221 input,
222 }),
223 }
224 },
225 description: Some("Tangent".into()),
226 },
227 Symbol::Func {
228 name: "asin".into(),
229 args: 1,
230 variadic: false,
231 callback: |args| f64_calc_1(args, |x| x.asin()),
232 description: Some("Arcsine".into()),
233 },
234 Symbol::Func {
235 name: "acos".into(),
236 args: 1,
237 variadic: false,
238 callback: |args| f64_calc_1(args, |x| x.acos()),
239 description: Some("Arccosine".into()),
240 },
241 Symbol::Func {
242 name: "atan".into(),
243 args: 1,
244 variadic: false,
245 callback: |args| f64_calc_1(args, |x| x.atan()),
246 description: Some("Arctangent".into()),
247 },
248 Symbol::Func {
249 name: "atan2".into(),
250 args: 2,
251 variadic: false,
252 callback: |args| f64_calc_2(args, |y, x| y.atan2(x)),
253 description: Some("Two-argument arctangent".into()),
254 },
255 Symbol::Func {
256 name: "sinh".into(),
257 args: 1,
258 variadic: false,
259 callback: |args| f64_calc_1(args, |x| x.sinh()),
260 description: Some("Hyperbolic sine".into()),
261 },
262 Symbol::Func {
263 name: "cosh".into(),
264 args: 1,
265 variadic: false,
266 callback: |args| f64_calc_1(args, |x| x.cosh()),
267 description: Some("Hyperbolic cosine".into()),
268 },
269 Symbol::Func {
270 name: "tanh".into(),
271 args: 1,
272 variadic: false,
273 callback: |args| f64_calc_1(args, |x| x.tanh()),
274 description: Some("Hyperbolic tangent".into()),
275 },
276 Symbol::Func {
278 name: "sqrt".into(),
279 args: 1,
280 variadic: false,
281 callback: |args| {
282 args[0]
283 .sqrt()
284 .ok_or_else(|| FuncError::NegativeSqrt { value: args[0] })
285 },
286 description: Some("Square root".into()),
287 },
288 Symbol::Func {
289 name: "cbrt".into(),
290 args: 1,
291 variadic: false,
292 callback: |args| f64_calc_1(args, |x| x.cbrt()),
293 description: Some("Cube root".into()),
294 },
295 Symbol::Func {
296 name: "pow".into(),
297 args: 2,
298 variadic: false,
299 callback: |args| {
300 let base = args[0];
301 let exponent = args[1];
302 match panic::catch_unwind(panic::AssertUnwindSafe(|| base.powd(exponent))) {
303 Ok(result) => Ok(result),
304 Err(_) => Err(FuncError::MathError {
305 message: format!("Power operation failed: {}^{}", base, exponent),
306 }),
307 }
308 },
309 description: Some("x raised to power y".into()),
310 },
311 Symbol::Func {
313 name: "log".into(),
314 args: 1,
315 variadic: false,
316 callback: |args| {
317 if args[0] <= Decimal::ZERO {
318 Err(FuncError::DomainError {
319 function: "log".to_string(),
320 input: args[0],
321 })
322 } else {
323 Ok(args[0].ln())
324 }
325 },
326 description: Some("Natural logarithm".into()),
327 },
328 Symbol::Func {
329 name: "log2".into(),
330 args: 1,
331 variadic: false,
332 callback: |args| f64_calc_1(args, |x| x.log2()),
333 description: Some("Base-2 logarithm".into()),
334 },
335 Symbol::Func {
336 name: "log10".into(),
337 args: 1,
338 variadic: false,
339 callback: |args| {
340 if args[0] <= Decimal::ZERO {
341 Err(FuncError::DomainError {
342 function: "log10".to_string(),
343 input: args[0],
344 })
345 } else {
346 Ok(args[0].log10())
347 }
348 },
349 description: Some("Base-10 logarithm".into()),
350 },
351 Symbol::Func {
352 name: "exp".into(),
353 args: 1,
354 variadic: false,
355 callback: |args| {
356 let input = args[0];
357 match panic::catch_unwind(panic::AssertUnwindSafe(|| input.exp())) {
358 Ok(result) => Ok(result),
359 Err(_) => Err(FuncError::MathError {
360 message: "Exponential overflow or underflow".to_string(),
361 }),
362 }
363 },
364 description: Some("e raised to power x".into()),
365 },
366 Symbol::Func {
367 name: "exp2".into(),
368 args: 1,
369 variadic: false,
370 callback: |args| f64_calc_1(args, |x| x.exp2()),
371 description: Some("2 raised to power x".into()),
372 },
373 Symbol::Func {
375 name: "abs".into(),
376 args: 1,
377 variadic: false,
378 callback: |args| Ok(args[0].abs()),
379 description: Some("Absolute value".into()),
380 },
381 Symbol::Func {
382 name: "sign".into(),
383 args: 1,
384 variadic: false,
385 callback: |args| Ok(args[0].signum()),
386 description: Some("Sign function (-1, 0, or 1)".into()),
387 },
388 Symbol::Func {
389 name: "floor".into(),
390 args: 1,
391 variadic: false,
392 callback: |args| Ok(args[0].floor()),
393 description: Some("Floor function".into()),
394 },
395 Symbol::Func {
396 name: "ceil".into(),
397 args: 1,
398 variadic: false,
399 callback: |args| Ok(args[0].ceil()),
400 description: Some("Ceiling function".into()),
401 },
402 Symbol::Func {
403 name: "round".into(),
404 args: 1,
405 variadic: false,
406 callback: |args| Ok(args[0].round()),
407 description: Some("Round to nearest integer".into()),
408 },
409 Symbol::Func {
410 name: "trunc".into(),
411 args: 1,
412 variadic: false,
413 callback: |args| Ok(args[0].trunc()),
414 description: Some("Truncate to integer".into()),
415 },
416 Symbol::Func {
417 name: "fract".into(),
418 args: 1,
419 variadic: false,
420 callback: |args| Ok(args[0].fract()),
421 description: Some("Fractional part".into()),
422 },
423 Symbol::Func {
424 name: "mod".into(),
425 args: 2,
426 variadic: false,
427 callback: |args| Ok(args[0] % args[1]),
428 description: Some("Remainder of x/y".into()),
429 },
430 Symbol::Func {
431 name: "hypot".into(),
432 args: 2,
433 variadic: false,
434 callback: |args| f64_calc_2(args, |x, y| x.hypot(y)),
435 description: Some("Euclidean distance sqrt(x²+y²)".into()),
436 },
437 Symbol::Func {
438 name: "clamp".into(),
439 args: 3,
440 variadic: false,
441 callback: |args| Ok(args[0].clamp(args[1].min(args[2]), args[2].max(args[1]))),
442 description: Some("Constrain value between bounds".into()),
443 },
444 Symbol::Func {
445 name: "if".into(),
446 args: 3,
447 variadic: false,
448 callback: |args| {
449 if args[0] != Decimal::ZERO {
450 Ok(args[1])
451 } else {
452 Ok(args[2])
453 }
454 },
455 description: Some(
456 "Conditional expression: if(condition, true_value, false_value)".into(),
457 ),
458 },
459 Symbol::Func {
461 name: "min".into(),
462 args: 1,
463 variadic: true,
464 callback: |args| {
465 Ok(*args.iter().min().ok_or_else(|| FuncError::MathError {
466 message: "min() requires at least one argument".to_string(),
467 })?)
468 },
469 description: Some("Minimum value".into()),
470 },
471 Symbol::Func {
472 name: "max".into(),
473 args: 1,
474 variadic: true,
475 callback: |args| {
476 Ok(*args.iter().max().ok_or_else(|| FuncError::MathError {
477 message: "max() requires at least one argument".to_string(),
478 })?)
479 },
480 description: Some("Maximum value".into()),
481 },
482 Symbol::Func {
483 name: "sum".into(),
484 args: 1,
485 variadic: true,
486 callback: |args| Ok(args.iter().sum()),
487 description: Some("Sum of values".into()),
488 },
489 Symbol::Func {
490 name: "avg".into(),
491 args: 1,
492 variadic: true,
493 callback: |args| {
494 let sum: Decimal = args.iter().sum();
495 let count = Decimal::from(args.len());
496 Ok(sum / count)
497 },
498 description: Some("Average of values".into()),
499 },
500 ],
501 }
502 }
503
504 pub fn add_const<S: Into<Cow<'static, str>>>(
508 &mut self,
509 name: S,
510 value: Decimal,
511 ) -> Result<&mut Self, SymbolError> {
512 let name = name.into();
513 if self.get(&name).is_some() {
514 return Err(SymbolError::DuplicateSymbol(name.to_string()));
515 }
516 self.symbols.push(Symbol::Const {
517 name,
518 value,
519 description: None,
520 });
521 Ok(self)
522 }
523
524 pub fn add_func<S: Into<Cow<'static, str>>>(
534 &mut self,
535 name: S,
536 args: usize,
537 variadic: bool,
538 callback: fn(&[Decimal]) -> Result<Decimal, FuncError>,
539 ) -> Result<&mut Self, SymbolError> {
540 let name = name.into();
541 if self.get(&name).is_some() {
542 return Err(SymbolError::DuplicateSymbol(name.to_string()));
543 }
544 self.symbols.push(Symbol::Func {
545 name,
546 args,
547 variadic,
548 callback,
549 description: None,
550 });
551 Ok(self)
552 }
553
554 pub fn get(&self, name: &str) -> Option<&Symbol> {
556 self.symbols
557 .iter()
558 .find(|sym| sym.name().eq_ignore_ascii_case(name))
559 }
560
561 pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
563 self.symbols.iter()
564 }
565}