julia_set/
function.rs

1use arithmetic_parser::{
2    grammars::{Features, NumGrammar, Parse, Untyped},
3    BinaryOp, Block, Expr, Lvalue, Spanned, SpannedExpr, Statement, UnaryOp,
4};
5use num_complex::Complex32;
6use thiserror::Error;
7
8use std::{collections::HashSet, error::Error, fmt, iter, mem, ops, str::FromStr};
9
10/// Error associated with creating a [`Function`].
11#[derive(Debug)]
12#[cfg_attr(
13    docsrs,
14    doc(cfg(any(
15        feature = "dyn_cpu_backend",
16        feature = "opencl_backend",
17        feature = "vulkan_backend"
18    )))
19)]
20pub struct FnError {
21    fragment: String,
22    line: u32,
23    column: usize,
24    source: ErrorSource,
25}
26
27#[derive(Debug)]
28enum ErrorSource {
29    Parse(String),
30    Eval(EvalError),
31}
32
33impl fmt::Display for ErrorSource {
34    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            Self::Parse(err) => write!(formatter, "[PARSE] {}", err),
37            Self::Eval(err) => write!(formatter, "[EVAL] {}", err),
38        }
39    }
40}
41
42#[derive(Debug, Error)]
43pub(crate) enum EvalError {
44    #[error("Last statement in function body is not an expression")]
45    NoReturn,
46    #[error("Useless expression")]
47    UselessExpr,
48    #[error("Cannot redefine variable")]
49    RedefinedVar,
50    #[error("Undefined variable")]
51    UndefinedVar,
52    #[error("Undefined function")]
53    UndefinedFn,
54    #[error("Function call has bogus arity")]
55    FnArity,
56    #[error("Unsupported language construct")]
57    Unsupported,
58}
59
/// Predefined unary functions that may be called in a function definition.
///
/// The enum is fieldless, so equality is total; `Eq` and `Hash` are derived in addition
/// to `PartialEq` so that values can be used as keys of hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum UnaryFunction {
    /// `arg`: argument (phase angle) of a complex number.
    Arg,
    /// `sqrt`: square root.
    Sqrt,
    /// `exp`: exponent.
    Exp,
    /// `log`: natural logarithm.
    Log,
    /// `sinh`: hyperbolic sine.
    Sinh,
    /// `cosh`: hyperbolic cosine.
    Cosh,
    /// `tanh`: hyperbolic tangent.
    Tanh,
    /// `asinh`: inverse hyperbolic sine.
    Asinh,
    /// `acosh`: inverse hyperbolic cosine.
    Acosh,
    /// `atanh`: inverse hyperbolic tangent.
    Atanh,
}
73
74impl UnaryFunction {
75    #[cfg(any(feature = "opencl_backend", feature = "vulkan_backend"))]
76    pub fn as_str(self) -> &'static str {
77        match self {
78            Self::Arg => "arg",
79            Self::Sqrt => "sqrt",
80            Self::Exp => "exp",
81            Self::Log => "log",
82            Self::Sinh => "sinh",
83            Self::Cosh => "cosh",
84            Self::Tanh => "tanh",
85            Self::Asinh => "asinh",
86            Self::Acosh => "acosh",
87            Self::Atanh => "atanh",
88        }
89    }
90
91    #[cfg(feature = "dyn_cpu_backend")]
92    pub fn eval(self, arg: Complex32) -> Complex32 {
93        match self {
94            Self::Arg => Complex32::new(arg.arg(), 0.0),
95            Self::Sqrt => arg.sqrt(),
96            Self::Exp => arg.exp(),
97            Self::Log => arg.ln(),
98            Self::Sinh => arg.sinh(),
99            Self::Cosh => arg.cosh(),
100            Self::Tanh => arg.tanh(),
101            Self::Asinh => arg.asinh(),
102            Self::Acosh => arg.acosh(),
103            Self::Atanh => arg.atanh(),
104        }
105    }
106}
107
108impl FromStr for UnaryFunction {
109    type Err = EvalError;
110
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        match s {
113            "arg" => Ok(Self::Arg),
114            "sqrt" => Ok(Self::Sqrt),
115            "exp" => Ok(Self::Exp),
116            "log" => Ok(Self::Log),
117            "sinh" => Ok(Self::Sinh),
118            "cosh" => Ok(Self::Cosh),
119            "tanh" => Ok(Self::Tanh),
120            "asinh" => Ok(Self::Asinh),
121            "acosh" => Ok(Self::Acosh),
122            "atanh" => Ok(Self::Atanh),
123            _ => Err(EvalError::UndefinedFn),
124        }
125    }
126}
127
128#[derive(Debug, Clone, PartialEq)]
129pub(crate) enum Evaluated {
130    Value(Complex32),
131    Variable(String),
132    Negation(Box<Evaluated>),
133    Binary {
134        op: BinaryOp,
135        lhs: Box<Evaluated>,
136        rhs: Box<Evaluated>,
137    },
138    FunctionCall {
139        function: UnaryFunction,
140        arg: Box<Evaluated>,
141    },
142}
143
144impl Evaluated {
145    fn is_commutative(op: BinaryOp) -> bool {
146        matches!(op, BinaryOp::Add | BinaryOp::Mul)
147    }
148
149    fn is_commutative_pair(op: BinaryOp, other_op: BinaryOp) -> bool {
150        op.priority() == other_op.priority() && op != BinaryOp::Power
151    }
152
153    fn fold(mut op: BinaryOp, mut lhs: Self, mut rhs: Self) -> Self {
154        // First, check if the both operands are values. In this case, we can eagerly compute
155        // the resulting value.
156        if let (Self::Value(lhs_val), Self::Value(rhs_val)) = (&lhs, &rhs) {
157            return Self::Value(match op {
158                BinaryOp::Add => *lhs_val + *rhs_val,
159                BinaryOp::Sub => *lhs_val - *rhs_val,
160                BinaryOp::Mul => *lhs_val * *rhs_val,
161                BinaryOp::Div => *lhs_val / *rhs_val,
162                BinaryOp::Power => lhs_val.powc(*rhs_val),
163                _ => unreachable!(),
164            });
165        }
166
167        if let Self::Value(val) = rhs {
168            // Convert an RHS value to use a commutative op (e.g., `+` instead of `-`).
169            // This will come in handy during later transforms.
170            //
171            // For example, this will transform `z - 1` into `z + -1`.
172            match op {
173                BinaryOp::Sub => {
174                    op = BinaryOp::Add;
175                    rhs = Self::Value(-val);
176                }
177                BinaryOp::Div => {
178                    op = BinaryOp::Mul;
179                    rhs = Self::Value(1.0 / val);
180                }
181                _ => { /* do nothing */ }
182            }
183        } else if let Self::Value(_) = lhs {
184            // Swap LHS and RHS to move the value to the right.
185            //
186            // For example, this will transform `1 + z` into `z + 1`.
187            if Self::is_commutative(op) {
188                mem::swap(&mut lhs, &mut rhs);
189            }
190        }
191
192        if let Self::Binary {
193            op: inner_op,
194            rhs: inner_rhs,
195            ..
196        } = &mut lhs
197        {
198            if Self::is_commutative_pair(*inner_op, op) {
199                if let Self::Value(inner_val) = **inner_rhs {
200                    if let Self::Value(val) = rhs {
201                        // Make the following replacement:
202                        //
203                        //    op             op
204                        //   /  \           /  \
205                        //  op  c   ---->  a  b op c
206                        // /  \
207                        // a  b
208                        let new_rhs = match op {
209                            BinaryOp::Add => inner_val + val,
210                            BinaryOp::Mul => inner_val * val,
211                            _ => unreachable!(),
212                            // ^-- We've replaced '-' and '/' `op`s previously.
213                        };
214
215                        *inner_rhs = Box::new(Self::Value(new_rhs));
216                        return lhs;
217                    } else {
218                        // Switch `inner_rhs` and `rhs`, moving a `Value` to the right.
219                        // For example, this will replace `z + 1 - z^2` to `z - z^2 + 1`.
220                        mem::swap(&mut rhs, inner_rhs);
221                        mem::swap(&mut op, inner_op);
222                    }
223                }
224            }
225        }
226
227        Self::Binary {
228            op,
229            lhs: Box::new(lhs),
230            rhs: Box::new(rhs),
231        }
232    }
233}
234
235impl ops::Neg for Evaluated {
236    type Output = Self;
237
238    fn neg(self) -> Self::Output {
239        match self {
240            Self::Value(val) => Self::Value(-val),
241            Self::Negation(inner) => *inner,
242            other => Self::Negation(Box::new(other)),
243        }
244    }
245}
246
247impl FnError {
248    fn parse(source: &arithmetic_parser::Error<'_>) -> Self {
249        let column = source.span().get_column();
250        Self {
251            fragment: (*source.span().fragment()).to_owned(),
252            line: source.span().location_line(),
253            column,
254            source: ErrorSource::Parse(source.kind().to_string()),
255        }
256    }
257
258    fn eval<T>(span: &Spanned<'_, T>, source: EvalError) -> Self {
259        let column = span.get_column();
260        Self {
261            fragment: (*span.fragment()).to_owned(),
262            line: span.location_line(),
263            column,
264            source: ErrorSource::Eval(source),
265        }
266    }
267}
268
269impl fmt::Display for FnError {
270    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
271        write!(formatter, "{}:{}: {}", self.line, self.column, self.source)?;
272        if formatter.alternate() {
273            formatter.write_str("\n")?;
274            formatter.pad(&self.fragment)?;
275        }
276        Ok(())
277    }
278}
279
280impl Error for FnError {
281    fn source(&self) -> Option<&(dyn Error + 'static)> {
282        match &self.source {
283            ErrorSource::Eval(e) => Some(e),
284            _ => None,
285        }
286    }
287}
288
289type FnGrammarBase = Untyped<NumGrammar<Complex32>>;
290
291#[derive(Debug, Clone, Copy)]
292struct FnGrammar;
293
294impl Parse for FnGrammar {
295    type Base = FnGrammarBase;
296    const FEATURES: Features = Features::empty();
297}
298
299#[derive(Debug)]
300pub(crate) struct Context {
301    variables: HashSet<String>,
302}
303
304impl Context {
305    pub(crate) fn new(arg_name: &str) -> Self {
306        Self {
307            variables: iter::once(arg_name.to_owned()).collect(),
308        }
309    }
310
311    fn process(
312        &mut self,
313        block: &Block<'_, FnGrammarBase>,
314        total_span: Spanned<'_>,
315    ) -> Result<Function, FnError> {
316        let mut assignments = vec![];
317        for statement in &block.statements {
318            match &statement.extra {
319                Statement::Assignment { lhs, rhs } => {
320                    let variable_name = match lhs.extra {
321                        Lvalue::Variable { .. } => *lhs.fragment(),
322                        _ => unreachable!("Tuples are disabled in parser"),
323                    };
324
325                    if self.variables.contains(variable_name) {
326                        let err = FnError::eval(lhs, EvalError::RedefinedVar);
327                        return Err(err);
328                    }
329
330                    // Evaluate the RHS.
331                    let value = self.eval_expr(rhs)?;
332                    self.variables.insert(variable_name.to_owned());
333                    assignments.push((variable_name.to_owned(), value));
334                }
335
336                Statement::Expr(_) => {
337                    return Err(FnError::eval(&statement, EvalError::UselessExpr));
338                }
339
340                _ => return Err(FnError::eval(&statement, EvalError::Unsupported)),
341            }
342        }
343
344        let return_value = block
345            .return_value
346            .as_ref()
347            .ok_or_else(|| FnError::eval(&total_span, EvalError::NoReturn))?;
348        let value = self.eval_expr(return_value)?;
349        assignments.push((String::new(), value));
350
351        Ok(Function { assignments })
352    }
353
354    fn eval_expr(&self, expr: &SpannedExpr<'_, FnGrammarBase>) -> Result<Evaluated, FnError> {
355        match &expr.extra {
356            Expr::Variable => {
357                let var_name = *expr.fragment();
358                self.variables
359                    .get(var_name)
360                    .ok_or_else(|| FnError::eval(expr, EvalError::UndefinedVar))?;
361
362                Ok(Evaluated::Variable(var_name.to_owned()))
363            }
364            Expr::Literal(lit) => Ok(Evaluated::Value(*lit)),
365
366            Expr::Unary { op, inner } => match op.extra {
367                UnaryOp::Neg => Ok(-self.eval_expr(inner)?),
368                _ => Err(FnError::eval(op, EvalError::Unsupported)),
369            },
370
371            Expr::Binary { lhs, op, rhs } => {
372                let lhs_value = self.eval_expr(lhs)?;
373                let rhs_value = self.eval_expr(rhs)?;
374
375                Ok(match op.extra {
376                    BinaryOp::Add
377                    | BinaryOp::Sub
378                    | BinaryOp::Mul
379                    | BinaryOp::Div
380                    | BinaryOp::Power => Evaluated::fold(op.extra, lhs_value, rhs_value),
381                    _ => {
382                        return Err(FnError::eval(op, EvalError::Unsupported));
383                    }
384                })
385            }
386
387            Expr::Function { name, args } => {
388                let fn_name = *name.fragment();
389                let function: UnaryFunction =
390                    fn_name.parse().map_err(|e| FnError::eval(name, e))?;
391
392                if args.len() != 1 {
393                    return Err(FnError::eval(expr, EvalError::FnArity));
394                }
395
396                Ok(Evaluated::FunctionCall {
397                    function,
398                    arg: Box::new(self.eval_expr(&args[0])?),
399                })
400            }
401
402            Expr::FnDefinition(_) | Expr::Block(_) | Expr::Tuple(_) | Expr::Method { .. } => {
403                unreachable!("Disabled in parser")
404            }
405
406            _ => Err(FnError::eval(expr, EvalError::Unsupported)),
407        }
408    }
409}
410
411/// Parsed complex-valued function of a single variable.
412///
413/// A `Function` instance can be created using [`FromStr`] trait. A function must use `z`
414/// as the (only) argument. A function may use arithmetic operations (`+`, `-`, `*`, `/`, `^`)
415/// and/or predefined unary functions:
416///
417/// - General functions: `arg`, `sqrt`, `exp`, `log`
418/// - Hyperbolic trigonometry: `sinh`, `cosh`, `tanh`
419/// - Inverse hyperbolic trigonometry: `asinh`, `acosh`, `atanh`
420///
421/// A function may define local variable assignment(s). The assignment syntax is similar to Python
422/// (or Rust, just without the `let` keyword): variable name followed by `=` and then by
423/// the arithmetic expression. Assignments must be separated by semicolons `;`. As in Rust,
424/// the last expression in function body is its return value.
425///
426/// # Examples
427///
428/// ```
429/// # use julia_set::Function;
430/// # fn main() -> anyhow::Result<()> {
431/// let function: Function = "z * z - 0.5".parse()?;
432/// let fn_with_calls: Function = "0.8 * z + z / atanh(z ^ -4)".parse()?;
433/// let fn_with_vars: Function = "c = -0.5 + 0.4i; z * z + c".parse()?;
434/// # Ok(())
435/// # }
436/// ```
437#[cfg_attr(
438    docsrs,
439    doc(cfg(any(
440        feature = "dyn_cpu_backend",
441        feature = "opencl_backend",
442        feature = "vulkan_backend"
443    )))
444)]
445#[derive(Debug, Clone)]
446pub struct Function {
447    assignments: Vec<(String, Evaluated)>,
448}
449
450impl Function {
451    pub(crate) fn assignments(&self) -> impl Iterator<Item = (&str, &Evaluated)> + '_ {
452        self.assignments.iter().filter_map(|(name, value)| {
453            if name.is_empty() {
454                None
455            } else {
456                Some((name.as_str(), value))
457            }
458        })
459    }
460
461    pub(crate) fn return_value(&self) -> &Evaluated {
462        &self.assignments.last().unwrap().1
463    }
464}
465
466impl FromStr for Function {
467    type Err = FnError;
468
469    fn from_str(s: &str) -> Result<Self, Self::Err> {
470        let statements = FnGrammar::parse_statements(s).map_err(|e| FnError::parse(&e))?;
471        let body_span = Spanned::from_str(s, ..);
472        Context::new("z").process(&statements, body_span)
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed reference to the variable `name`.
    fn var(name: &str) -> Box<Evaluated> {
        Box::new(Evaluated::Variable(name.to_owned()))
    }

    /// Boxed constant `re + im * i`.
    fn constant(re: f32, im: f32) -> Box<Evaluated> {
        Box::new(Evaluated::Value(Complex32::new(re, im)))
    }

    /// Binary node `lhs op rhs`.
    fn binary(op: BinaryOp, lhs: Box<Evaluated>, rhs: Box<Evaluated>) -> Evaluated {
        Evaluated::Binary { op, lhs, rhs }
    }

    /// Boxed call `sinh(arg)`.
    fn sinh(arg: Box<Evaluated>) -> Box<Evaluated> {
        Box::new(Evaluated::FunctionCall {
            function: UnaryFunction::Sinh,
            arg,
        })
    }

    fn z_square() -> Box<Evaluated> {
        Box::new(binary(BinaryOp::Mul, var("z"), var("z")))
    }

    /// Parses `source` and checks that the function consists of the return value only,
    /// and that this value equals `expected`.
    fn assert_return_only(source: &str, expected: Evaluated) {
        let function: Function = source.parse().unwrap();
        assert_eq!(function.assignments, vec![(String::new(), expected)]);
    }

    #[test]
    fn simple_function() {
        assert_return_only(
            "z*z + (0.77 - 0.2i)",
            binary(BinaryOp::Add, z_square(), constant(0.77, -0.2)),
        );
    }

    #[test]
    fn simple_function_with_rewrite_rules() {
        let z_times_four = Box::new(binary(BinaryOp::Mul, var("z"), constant(4.0, 0.0)));
        assert_return_only(
            "z / 0.25 - 0.1i + (0.77 - 0.1i)",
            binary(BinaryOp::Add, z_times_four, constant(0.77, -0.2)),
        );
    }

    #[test]
    fn function_with_several_rewrite_rules() {
        let z_minus_square = Box::new(binary(BinaryOp::Sub, var("z"), z_square()));
        assert_return_only(
            "z + 0.1 - z*z + 0.3i",
            binary(BinaryOp::Add, z_minus_square, constant(0.1, 0.3)),
        );
    }

    #[test]
    fn simple_function_with_mul_rewrite_rules() {
        let shifted_z = Box::new(binary(BinaryOp::Add, var("z"), constant(-5.0, 0.0)));
        assert_return_only(
            "sinh(z - 5) / 4. * 6i",
            binary(BinaryOp::Mul, sinh(shifted_z), constant(0.0, 1.5)),
        );
    }

    #[test]
    fn simple_function_with_redistribution() {
        assert_return_only(
            "0.5 + sinh(z) - 0.2i",
            binary(BinaryOp::Add, sinh(var("z")), constant(0.5, -0.2)),
        );
    }

    #[test]
    fn function_with_assignments() {
        let function: Function = "c = 0.5 - 0.2i; z*z + c".parse().unwrap();
        let expected_return = binary(BinaryOp::Add, z_square(), var("c"));

        assert_eq!(
            function.assignments,
            vec![
                ("c".to_owned(), *constant(0.5, -0.2)),
                (String::new(), expected_return),
            ]
        );
    }
}