precise_calc/
ast.rs

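//! Contains the types used to represent the AST (atoms, operators, expressions, statements),
//! together with the builtin and user-defined function types that expressions can call.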
2
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::iter::zip;
6
7use serde::{Deserialize, Serialize};
8
9use crate::context::Context;
10use crate::eval::eval_expr;
11use crate::{CalcError, CalcResult, Number};
12
/// Holds a reference to a builtin or user-defined function.
///
/// This is the borrowed counterpart of `CalcFunc`: each variant holds a shared reference to the
/// function instead of owning it. Use `CalcFuncRef::call` to invoke it with arity checking.
pub enum CalcFuncRef<'a> {
    /// Builtin function (like log or sqrt).
    Builtin(&'a BuiltinFunc),

    /// User-defined function.
    UserDef(&'a UserFunc),
}
21
22impl<'a> CalcFuncRef<'a> {
23    /// Call the function with arguments `args` in context `ctx`.
24    pub fn call(&self, args: &[Number], ctx: &Context) -> CalcResult {
25        match self {
26            CalcFuncRef::Builtin(func) => {
27                // First, check arity
28                if func.arity != args.len() {
29                    return Err(CalcError::IncorrectArity(func.arity, args.len()));
30                }
31
32                // Call function
33                func.apply(args, ctx)
34            }
35            CalcFuncRef::UserDef(func) => {
36                // Check arity
37                if func.arity() != args.len() {
38                    return Err(CalcError::IncorrectArity(func.arity(), args.len()));
39                }
40
41                func.apply(args, ctx)
42            }
43        }
44    }
45}
46
/// A function that can be called by the user.
///
/// Owning counterpart of `CalcFuncRef`; each variant stores the function by value.
#[derive(Clone)]
pub enum CalcFunc {
    /// Builtin function (like log or sqrt).
    Builtin(BuiltinFunc),

    /// User-defined function.
    UserDef(UserFunc),
}
56
/// Builtin function (like log or sqrt).
///
/// Pairs a plain function pointer with the number of arguments it expects.
#[derive(Clone)]
pub struct BuiltinFunc {
    /// The number of arguments that the function takes.
    pub arity: usize,

    /// Function pointer run by the `apply` method with the call's arguments and context.
    // TODO: Probably doesn't need context.
    apply: fn(&[Number], &Context) -> CalcResult,
}
66
67impl BuiltinFunc {
68    pub(crate) fn new(arity: usize, apply: fn(&[Number], &Context) -> CalcResult) -> BuiltinFunc {
69        BuiltinFunc { arity, apply }
70    }
71
72    /// Call the function on `args` in `ctx`.
73    pub fn apply(&self, args: &[Number], ctx: &Context) -> CalcResult {
74        (self.apply)(args, ctx)
75    }
76}
77
78impl Debug for BuiltinFunc {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        write!(f, "builtin function")
81    }
82}
83
/// User-defined function.
///
/// Stores the parameter names and the body expression. Applying it evaluates `body` with each
/// parameter bound to the corresponding argument.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct UserFunc {
    /// Parameter names, in the order arguments are supplied.
    bindings: Vec<String>,
    /// Expression evaluated when the function is applied.
    body: Expr,
}
90
91impl UserFunc {
92    /// Create a new function with parameters `bindings` (in order) and body expression `body`.
93    pub fn new(bindings: Vec<String>, body: Expr) -> UserFunc {
94        UserFunc { bindings, body }
95    }
96
97    /// Call the function on `args` in `ctx`.
98    pub fn apply(&self, args: &[Number], ctx: &Context) -> CalcResult {
99        // Create evaluation scope
100        let mut eval_scope = ctx.clone();
101        let bindings = HashMap::from_iter(zip(self.bindings.iter().cloned(), args.iter().cloned()));
102        eval_scope.add_scope(bindings);
103
104        // Evaluate function body in new scope
105        eval_expr(&self.body, &eval_scope)
106    }
107
108    /// Return the number of arguments the function takes.
109    pub fn arity(&self) -> usize {
110        self.bindings.len()
111    }
112}
113
/// Leaf value of an expression (wrapped by `Expr::AtomExpr`).
#[allow(missing_docs)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Atom {
    /// A named symbol. Presumably resolved against the `Context` during evaluation; see `eval`.
    Symbol(String),
    /// A numeric literal.
    Num(Number),
}
120
/// Operator of an `Expr::BinaryExpr`, applied to its `lhs` and `rhs` operands.
///
/// NOTE(review): the arithmetic meaning of each variant is implemented in `eval`, not here;
/// the variant docs below follow the names.
#[allow(missing_docs)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum BinaryOp {
    /// Addition.
    Plus,
    /// Subtraction.
    Minus,
    /// Multiplication.
    Times,
    /// Division.
    Divide,
    /// Exponentiation.
    Power,
}
130
/// Operator of an `Expr::UnaryExpr`, applied to its single operand.
#[allow(missing_docs)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum UnaryOp {
    /// Negation (presumably arithmetic `-x`; see `eval`).
    Negate,
}
136
/// Expression node of the AST.
#[allow(missing_docs)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Expr {
    /// Leaf expression: a symbol or a numeric literal.
    AtomExpr(Atom),
    /// Unary operator `op` applied to the operand `data`.
    UnaryExpr {
        op: UnaryOp,
        data: Box<Expr>,
    },
    /// Binary operator `op` applied to the operands `lhs` and `rhs`.
    BinaryExpr {
        lhs: Box<Expr>,
        rhs: Box<Expr>,
        op: BinaryOp,
    },
    /// Call of the function named `function` with argument expressions `args`.
    // NOTE(review): lookup of `function` (builtin vs. user-defined) happens in `eval`; confirm there.
    FunctionCall {
        function: String,
        args: Vec<Expr>,
    },
}
155
/// Statement of the calculator language.
#[allow(missing_docs)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Stmt {
    /// Function definition named `name` with parameter names `params` and body `body`.
    // Presumably turned into a `UserFunc` (`params` -> bindings) by the caller; verify in eval/context.
    FuncDef {
        name: String,
        params: Vec<String>,
        body: Expr,
    },
    /// Assignment of the expression `value` to the name `name`.
    Assignment {
        name: String,
        value: Expr,
    },
    /// A bare expression used as a statement.
    ExprStmt(Expr),
}