1use arithmetic_parser::{
2 grammars::{Features, NumGrammar, Parse, Untyped},
3 BinaryOp, Block, Expr, Lvalue, Spanned, SpannedExpr, Statement, UnaryOp,
4};
5use num_complex::Complex32;
6use thiserror::Error;
7
8use std::{collections::HashSet, error::Error, fmt, iter, mem, ops, str::FromStr};
9
10#[derive(Debug)]
12#[cfg_attr(
13 docsrs,
14 doc(cfg(any(
15 feature = "dyn_cpu_backend",
16 feature = "opencl_backend",
17 feature = "vulkan_backend"
18 )))
19)]
20pub struct FnError {
21 fragment: String,
22 line: u32,
23 column: usize,
24 source: ErrorSource,
25}
26
27#[derive(Debug)]
28enum ErrorSource {
29 Parse(String),
30 Eval(EvalError),
31}
32
33impl fmt::Display for ErrorSource {
34 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 Self::Parse(err) => write!(formatter, "[PARSE] {}", err),
37 Self::Eval(err) => write!(formatter, "[EVAL] {}", err),
38 }
39 }
40}
41
42#[derive(Debug, Error)]
43pub(crate) enum EvalError {
44 #[error("Last statement in function body is not an expression")]
45 NoReturn,
46 #[error("Useless expression")]
47 UselessExpr,
48 #[error("Cannot redefine variable")]
49 RedefinedVar,
50 #[error("Undefined variable")]
51 UndefinedVar,
52 #[error("Undefined function")]
53 UndefinedFn,
54 #[error("Function call has bogus arity")]
55 FnArity,
56 #[error("Unsupported language construct")]
57 Unsupported,
58}
59
/// Unary functions that can be called from a function body.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum UnaryFunction {
    /// `arg(z)`.
    Arg,
    /// `sqrt(z)`.
    Sqrt,
    /// `exp(z)`.
    Exp,
    /// `log(z)`.
    Log,
    /// `sinh(z)`.
    Sinh,
    /// `cosh(z)`.
    Cosh,
    /// `tanh(z)`.
    Tanh,
    /// `asinh(z)`.
    Asinh,
    /// `acosh(z)`.
    Acosh,
    /// `atanh(z)`.
    Atanh,
}
73
74impl UnaryFunction {
75 #[cfg(any(feature = "opencl_backend", feature = "vulkan_backend"))]
76 pub fn as_str(self) -> &'static str {
77 match self {
78 Self::Arg => "arg",
79 Self::Sqrt => "sqrt",
80 Self::Exp => "exp",
81 Self::Log => "log",
82 Self::Sinh => "sinh",
83 Self::Cosh => "cosh",
84 Self::Tanh => "tanh",
85 Self::Asinh => "asinh",
86 Self::Acosh => "acosh",
87 Self::Atanh => "atanh",
88 }
89 }
90
91 #[cfg(feature = "dyn_cpu_backend")]
92 pub fn eval(self, arg: Complex32) -> Complex32 {
93 match self {
94 Self::Arg => Complex32::new(arg.arg(), 0.0),
95 Self::Sqrt => arg.sqrt(),
96 Self::Exp => arg.exp(),
97 Self::Log => arg.ln(),
98 Self::Sinh => arg.sinh(),
99 Self::Cosh => arg.cosh(),
100 Self::Tanh => arg.tanh(),
101 Self::Asinh => arg.asinh(),
102 Self::Acosh => arg.acosh(),
103 Self::Atanh => arg.atanh(),
104 }
105 }
106}
107
108impl FromStr for UnaryFunction {
109 type Err = EvalError;
110
111 fn from_str(s: &str) -> Result<Self, Self::Err> {
112 match s {
113 "arg" => Ok(Self::Arg),
114 "sqrt" => Ok(Self::Sqrt),
115 "exp" => Ok(Self::Exp),
116 "log" => Ok(Self::Log),
117 "sinh" => Ok(Self::Sinh),
118 "cosh" => Ok(Self::Cosh),
119 "tanh" => Ok(Self::Tanh),
120 "asinh" => Ok(Self::Asinh),
121 "acosh" => Ok(Self::Acosh),
122 "atanh" => Ok(Self::Atanh),
123 _ => Err(EvalError::UndefinedFn),
124 }
125 }
126}
127
128#[derive(Debug, Clone, PartialEq)]
129pub(crate) enum Evaluated {
130 Value(Complex32),
131 Variable(String),
132 Negation(Box<Evaluated>),
133 Binary {
134 op: BinaryOp,
135 lhs: Box<Evaluated>,
136 rhs: Box<Evaluated>,
137 },
138 FunctionCall {
139 function: UnaryFunction,
140 arg: Box<Evaluated>,
141 },
142}
143
144impl Evaluated {
145 fn is_commutative(op: BinaryOp) -> bool {
146 matches!(op, BinaryOp::Add | BinaryOp::Mul)
147 }
148
149 fn is_commutative_pair(op: BinaryOp, other_op: BinaryOp) -> bool {
150 op.priority() == other_op.priority() && op != BinaryOp::Power
151 }
152
153 fn fold(mut op: BinaryOp, mut lhs: Self, mut rhs: Self) -> Self {
154 if let (Self::Value(lhs_val), Self::Value(rhs_val)) = (&lhs, &rhs) {
157 return Self::Value(match op {
158 BinaryOp::Add => *lhs_val + *rhs_val,
159 BinaryOp::Sub => *lhs_val - *rhs_val,
160 BinaryOp::Mul => *lhs_val * *rhs_val,
161 BinaryOp::Div => *lhs_val / *rhs_val,
162 BinaryOp::Power => lhs_val.powc(*rhs_val),
163 _ => unreachable!(),
164 });
165 }
166
167 if let Self::Value(val) = rhs {
168 match op {
173 BinaryOp::Sub => {
174 op = BinaryOp::Add;
175 rhs = Self::Value(-val);
176 }
177 BinaryOp::Div => {
178 op = BinaryOp::Mul;
179 rhs = Self::Value(1.0 / val);
180 }
181 _ => { }
182 }
183 } else if let Self::Value(_) = lhs {
184 if Self::is_commutative(op) {
188 mem::swap(&mut lhs, &mut rhs);
189 }
190 }
191
192 if let Self::Binary {
193 op: inner_op,
194 rhs: inner_rhs,
195 ..
196 } = &mut lhs
197 {
198 if Self::is_commutative_pair(*inner_op, op) {
199 if let Self::Value(inner_val) = **inner_rhs {
200 if let Self::Value(val) = rhs {
201 let new_rhs = match op {
209 BinaryOp::Add => inner_val + val,
210 BinaryOp::Mul => inner_val * val,
211 _ => unreachable!(),
212 };
214
215 *inner_rhs = Box::new(Self::Value(new_rhs));
216 return lhs;
217 } else {
218 mem::swap(&mut rhs, inner_rhs);
221 mem::swap(&mut op, inner_op);
222 }
223 }
224 }
225 }
226
227 Self::Binary {
228 op,
229 lhs: Box::new(lhs),
230 rhs: Box::new(rhs),
231 }
232 }
233}
234
235impl ops::Neg for Evaluated {
236 type Output = Self;
237
238 fn neg(self) -> Self::Output {
239 match self {
240 Self::Value(val) => Self::Value(-val),
241 Self::Negation(inner) => *inner,
242 other => Self::Negation(Box::new(other)),
243 }
244 }
245}
246
247impl FnError {
248 fn parse(source: &arithmetic_parser::Error<'_>) -> Self {
249 let column = source.span().get_column();
250 Self {
251 fragment: (*source.span().fragment()).to_owned(),
252 line: source.span().location_line(),
253 column,
254 source: ErrorSource::Parse(source.kind().to_string()),
255 }
256 }
257
258 fn eval<T>(span: &Spanned<'_, T>, source: EvalError) -> Self {
259 let column = span.get_column();
260 Self {
261 fragment: (*span.fragment()).to_owned(),
262 line: span.location_line(),
263 column,
264 source: ErrorSource::Eval(source),
265 }
266 }
267}
268
269impl fmt::Display for FnError {
270 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
271 write!(formatter, "{}:{}: {}", self.line, self.column, self.source)?;
272 if formatter.alternate() {
273 formatter.write_str("\n")?;
274 formatter.pad(&self.fragment)?;
275 }
276 Ok(())
277 }
278}
279
280impl Error for FnError {
281 fn source(&self) -> Option<&(dyn Error + 'static)> {
282 match &self.source {
283 ErrorSource::Eval(e) => Some(e),
284 _ => None,
285 }
286 }
287}
288
289type FnGrammarBase = Untyped<NumGrammar<Complex32>>;
290
/// Marker type for the grammar used to parse function bodies.
#[derive(Clone, Copy, Debug)]
struct FnGrammar;
293
294impl Parse for FnGrammar {
295 type Base = FnGrammarBase;
296 const FEATURES: Features = Features::empty();
297}
298
/// Context in which a function body is evaluated.
#[derive(Debug)]
pub(crate) struct Context {
    /// Names of the variables defined so far.
    variables: HashSet<String>,
}
303
304impl Context {
305 pub(crate) fn new(arg_name: &str) -> Self {
306 Self {
307 variables: iter::once(arg_name.to_owned()).collect(),
308 }
309 }
310
311 fn process(
312 &mut self,
313 block: &Block<'_, FnGrammarBase>,
314 total_span: Spanned<'_>,
315 ) -> Result<Function, FnError> {
316 let mut assignments = vec![];
317 for statement in &block.statements {
318 match &statement.extra {
319 Statement::Assignment { lhs, rhs } => {
320 let variable_name = match lhs.extra {
321 Lvalue::Variable { .. } => *lhs.fragment(),
322 _ => unreachable!("Tuples are disabled in parser"),
323 };
324
325 if self.variables.contains(variable_name) {
326 let err = FnError::eval(lhs, EvalError::RedefinedVar);
327 return Err(err);
328 }
329
330 let value = self.eval_expr(rhs)?;
332 self.variables.insert(variable_name.to_owned());
333 assignments.push((variable_name.to_owned(), value));
334 }
335
336 Statement::Expr(_) => {
337 return Err(FnError::eval(&statement, EvalError::UselessExpr));
338 }
339
340 _ => return Err(FnError::eval(&statement, EvalError::Unsupported)),
341 }
342 }
343
344 let return_value = block
345 .return_value
346 .as_ref()
347 .ok_or_else(|| FnError::eval(&total_span, EvalError::NoReturn))?;
348 let value = self.eval_expr(return_value)?;
349 assignments.push((String::new(), value));
350
351 Ok(Function { assignments })
352 }
353
354 fn eval_expr(&self, expr: &SpannedExpr<'_, FnGrammarBase>) -> Result<Evaluated, FnError> {
355 match &expr.extra {
356 Expr::Variable => {
357 let var_name = *expr.fragment();
358 self.variables
359 .get(var_name)
360 .ok_or_else(|| FnError::eval(expr, EvalError::UndefinedVar))?;
361
362 Ok(Evaluated::Variable(var_name.to_owned()))
363 }
364 Expr::Literal(lit) => Ok(Evaluated::Value(*lit)),
365
366 Expr::Unary { op, inner } => match op.extra {
367 UnaryOp::Neg => Ok(-self.eval_expr(inner)?),
368 _ => Err(FnError::eval(op, EvalError::Unsupported)),
369 },
370
371 Expr::Binary { lhs, op, rhs } => {
372 let lhs_value = self.eval_expr(lhs)?;
373 let rhs_value = self.eval_expr(rhs)?;
374
375 Ok(match op.extra {
376 BinaryOp::Add
377 | BinaryOp::Sub
378 | BinaryOp::Mul
379 | BinaryOp::Div
380 | BinaryOp::Power => Evaluated::fold(op.extra, lhs_value, rhs_value),
381 _ => {
382 return Err(FnError::eval(op, EvalError::Unsupported));
383 }
384 })
385 }
386
387 Expr::Function { name, args } => {
388 let fn_name = *name.fragment();
389 let function: UnaryFunction =
390 fn_name.parse().map_err(|e| FnError::eval(name, e))?;
391
392 if args.len() != 1 {
393 return Err(FnError::eval(expr, EvalError::FnArity));
394 }
395
396 Ok(Evaluated::FunctionCall {
397 function,
398 arg: Box::new(self.eval_expr(&args[0])?),
399 })
400 }
401
402 Expr::FnDefinition(_) | Expr::Block(_) | Expr::Tuple(_) | Expr::Method { .. } => {
403 unreachable!("Disabled in parser")
404 }
405
406 _ => Err(FnError::eval(expr, EvalError::Unsupported)),
407 }
408 }
409}
410
411#[cfg_attr(
438 docsrs,
439 doc(cfg(any(
440 feature = "dyn_cpu_backend",
441 feature = "opencl_backend",
442 feature = "vulkan_backend"
443 )))
444)]
445#[derive(Debug, Clone)]
446pub struct Function {
447 assignments: Vec<(String, Evaluated)>,
448}
449
450impl Function {
451 pub(crate) fn assignments(&self) -> impl Iterator<Item = (&str, &Evaluated)> + '_ {
452 self.assignments.iter().filter_map(|(name, value)| {
453 if name.is_empty() {
454 None
455 } else {
456 Some((name.as_str(), value))
457 }
458 })
459 }
460
461 pub(crate) fn return_value(&self) -> &Evaluated {
462 &self.assignments.last().unwrap().1
463 }
464}
465
466impl FromStr for Function {
467 type Err = FnError;
468
469 fn from_str(s: &str) -> Result<Self, Self::Err> {
470 let statements = FnGrammar::parse_statements(s).map_err(|e| FnError::parse(&e))?;
471 let body_span = Spanned::from_str(s, ..);
472 Context::new("z").process(&statements, body_span)
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Shortcut for a variable reference.
    fn var(name: &str) -> Evaluated {
        Evaluated::Variable(name.to_owned())
    }

    /// Shortcut for a complex constant.
    fn constant(re: f32, im: f32) -> Evaluated {
        Evaluated::Value(Complex32::new(re, im))
    }

    /// Shortcut for a binary operation.
    fn binary(op: BinaryOp, lhs: Evaluated, rhs: Evaluated) -> Evaluated {
        Evaluated::Binary {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }

    /// Shortcut for `sinh(arg)`.
    fn sinh(arg: Evaluated) -> Evaluated {
        Evaluated::FunctionCall {
            function: UnaryFunction::Sinh,
            arg: Box::new(arg),
        }
    }

    /// Expression for `z * z`.
    fn z_square() -> Evaluated {
        binary(BinaryOp::Mul, var("z"), var("z"))
    }

    /// Parses `source` and checks that the resulting function consists solely
    /// of the `expected` return value.
    fn assert_return_value(source: &str, expected: Evaluated) {
        let function: Function = source.parse().unwrap();
        assert_eq!(function.assignments, vec![(String::new(), expected)]);
    }

    #[test]
    fn simple_function() {
        assert_return_value(
            "z*z + (0.77 - 0.2i)",
            binary(BinaryOp::Add, z_square(), constant(0.77, -0.2)),
        );
    }

    #[test]
    fn simple_function_with_rewrite_rules() {
        // Division by a constant becomes multiplication; additive constants are merged.
        let scaled_z = binary(BinaryOp::Mul, var("z"), constant(4.0, 0.0));
        assert_return_value(
            "z / 0.25 - 0.1i + (0.77 - 0.1i)",
            binary(BinaryOp::Add, scaled_z, constant(0.77, -0.2)),
        );
    }

    #[test]
    fn function_with_several_rewrite_rules() {
        // Constants are moved outwards past the non-constant term and then merged.
        let difference = binary(BinaryOp::Sub, var("z"), z_square());
        assert_return_value(
            "z + 0.1 - z*z + 0.3i",
            binary(BinaryOp::Add, difference, constant(0.1, 0.3)),
        );
    }

    #[test]
    fn simple_function_with_mul_rewrite_rules() {
        let sinh_arg = binary(BinaryOp::Add, var("z"), constant(-5.0, 0.0));
        assert_return_value(
            "sinh(z - 5) / 4. * 6i",
            binary(BinaryOp::Mul, sinh(sinh_arg), constant(0.0, 1.5)),
        );
    }

    #[test]
    fn simple_function_with_redistribution() {
        // The leading constant is moved to the right and merged with the trailing one.
        assert_return_value(
            "0.5 + sinh(z) - 0.2i",
            binary(BinaryOp::Add, sinh(var("z")), constant(0.5, -0.2)),
        );
    }

    #[test]
    fn function_with_assignments() {
        let function: Function = "c = 0.5 - 0.2i; z*z + c".parse().unwrap();
        let expected_assignments = vec![
            ("c".to_owned(), constant(0.5, -0.2)),
            (
                String::new(),
                binary(BinaryOp::Add, z_square(), var("c")),
            ),
        ];
        assert_eq!(function.assignments, expected_assignments);
    }
}