etk_asm/ops/expression.rs

1use super::macros::{ExpressionMacroInvocation, MacroDefinition};
2use num_bigint::BigInt;
3use snafu::OptionExt;
4use snafu::{Backtrace, Snafu};
5use std::collections::HashMap;
6use std::fmt::{self, Debug};
7
8/// An error that arises when an expression cannot be evaluated.
9#[derive(Snafu, Debug)]
10#[snafu(context(suffix(false)), visibility(pub))]
11pub enum Error {
12    #[snafu(display("unknown label `{}`", label))]
13    #[non_exhaustive]
14    UnknownLabel { label: String, backtrace: Backtrace },
15
16    #[snafu(display("unknown macro `{}`", name))]
17    #[non_exhaustive]
18    UnknownMacro { name: String, backtrace: Backtrace },
19
20    #[snafu(display("undefined macro variable `{}`", name))]
21    #[non_exhaustive]
22    UndefinedVariable { name: String, backtrace: Backtrace },
23}
24
25type LabelsMap = HashMap<String, Option<usize>>;
26type VariablesMap = HashMap<String, Expression>;
27type MacrosMap = HashMap<String, MacroDefinition>;
28
29/// Evaluation context for `Expression`.
30#[derive(Clone, Copy, Debug, Default)]
31pub struct Context<'a> {
32    labels: Option<&'a LabelsMap>,
33    macros: Option<&'a MacrosMap>,
34    variables: Option<&'a VariablesMap>,
35}
36
37impl<'a> Context<'a> {
38    /// Looks up a label in the current context.
39    pub fn get_label(&self, key: &str) -> Option<&Option<usize>> {
40        match self.labels {
41            Some(labels) => labels.get(key),
42            None => None,
43        }
44    }
45
46    /// Looks up a macro in the current context.
47    pub fn get_macro(&self, key: &str) -> Option<&MacroDefinition> {
48        match self.macros {
49            Some(macros) => macros.get(key),
50            None => None,
51        }
52    }
53
54    /// Looks up a variable in the current context.
55    pub fn get_variable(&self, key: &str) -> Option<&Expression> {
56        match self.variables {
57            Some(variables) => variables.get(key),
58            None => None,
59        }
60    }
61}
62
63impl<'a> From<&'a LabelsMap> for Context<'a> {
64    fn from(labels: &'a LabelsMap) -> Self {
65        Self {
66            labels: Some(labels),
67            macros: None,
68            variables: None,
69        }
70    }
71}
72
73impl<'a> From<(&'a LabelsMap, &'a MacrosMap)> for Context<'a> {
74    fn from(x: (&'a LabelsMap, &'a MacrosMap)) -> Self {
75        Self {
76            labels: Some(x.0),
77            macros: Some(x.1),
78            variables: None,
79        }
80    }
81}
82
83impl<'a> From<(&'a LabelsMap, &'a MacrosMap, &'a VariablesMap)> for Context<'a> {
84    fn from(x: (&'a LabelsMap, &'a MacrosMap, &'a VariablesMap)) -> Self {
85        Self {
86            labels: Some(x.0),
87            macros: Some(x.1),
88            variables: Some(x.2),
89        }
90    }
91}
92
93/// A mathematical expression.
94#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
95pub enum Expression {
96    /// A mathematical expression.
97    Expression(Box<Self>),
98
99    /// An expression macro invocation.
100    Macro(ExpressionMacroInvocation),
101
102    /// A terminal value.
103    Terminal(Terminal),
104
105    /// An addition operation.
106    Plus(Box<Self>, Box<Self>),
107
108    /// A subtraction operation.
109    Minus(Box<Self>, Box<Self>),
110
111    /// A multiplication operation.
112    Times(Box<Self>, Box<Self>),
113
114    /// A division operation.
115    Divide(Box<Self>, Box<Self>),
116}
117
118impl Debug for Expression {
119    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120        match self {
121            Expression::Expression(s) => write!(f, r#"({:?})"#, s),
122            Expression::Macro(m) => write!(f, r#"Expression::Macro("{}")"#, m.name),
123            Expression::Terminal(t) => write!(f, r#"Expression::Terminal({:?})"#, t),
124            Expression::Plus(lhs, rhs) => write!(f, r#"Expression::Plus({:?}, {:?})"#, lhs, rhs),
125            Expression::Minus(lhs, rhs) => write!(f, r#"Expression::Minus({:?}, {:?})"#, lhs, rhs),
126            Expression::Times(lhs, rhs) => write!(f, r#"Expression::Times({:?}, {:?})"#, lhs, rhs),
127            Expression::Divide(lhs, rhs) => {
128                write!(f, r#"Expression::Divide({:?}, {:?})"#, lhs, rhs)
129            }
130        }
131    }
132}
133
134impl fmt::Display for Expression {
135    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136        match self {
137            Expression::Expression(s) => write!(f, r#"({})"#, s),
138            Expression::Macro(m) => write!(f, r#"{}"#, m),
139            Expression::Terminal(t) => write!(f, r#"{}"#, t),
140            Expression::Plus(lhs, rhs) => write!(f, r#"{}+{}"#, lhs, rhs),
141            Expression::Minus(lhs, rhs) => write!(f, r#"{}-{}"#, lhs, rhs),
142            Expression::Times(lhs, rhs) => write!(f, r#"{}*{}"#, lhs, rhs),
143            Expression::Divide(lhs, rhs) => write!(f, r#"{}/{}"#, lhs, rhs),
144        }
145    }
146}
147
148/// A terminal value in an expression.
149#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
150pub enum Terminal {
151    /// An integer value.
152    Number(BigInt),
153
154    /// A label.
155    Label(String),
156
157    /// A macro variable.
158    Variable(String),
159}
160
161impl Terminal {
162    /// Evaluates a terminal into an integer value.
163    pub fn eval(&self) -> Result<BigInt, Error> {
164        self.eval_with_context(Context::default())
165    }
166
167    /// Evaluates a terminal into an integer value, with a given given `Context`..
168    pub fn eval_with_context(&self, ctx: Context) -> Result<BigInt, Error> {
169        let ret = match self {
170            Terminal::Number(n) => n.clone(),
171            Terminal::Label(label) => ctx
172                .get_label(label)
173                .context(UnknownLabel { label })?
174                .context(UnknownLabel { label })?
175                .into(),
176            Terminal::Variable(name) => ctx
177                .get_variable(name)
178                .context(UndefinedVariable { name })?
179                .eval_with_context(ctx)?,
180        };
181
182        Ok(ret)
183    }
184}
185
186impl Expression {
187    /// Returns the constant value of the expression.
188    pub fn eval(&self) -> Result<BigInt, Error> {
189        self.eval_with_context(Context::default())
190    }
191
192    /// Evaluates the expression given a certain `Context`.
193    pub fn eval_with_context(&self, ctx: Context) -> Result<BigInt, Error> {
194        fn eval(e: &Expression, ctx: Context) -> Result<BigInt, Error> {
195            let ret = match e {
196                Expression::Expression(expr) => eval(expr, ctx)?,
197                Expression::Macro(invc) => {
198                    let defn = ctx.get_macro(&invc.name).context(UnknownMacro {
199                        name: invc.name.clone(),
200                    })?;
201
202                    let vars = defn
203                        .parameters()
204                        .iter()
205                        .cloned()
206                        .zip(invc.parameters.iter().cloned())
207                        .collect();
208
209                    let mut ctx = ctx;
210                    ctx.variables = Some(&vars);
211
212                    defn.unwrap_expression()
213                        .content
214                        .tree
215                        .eval_with_context(ctx)?
216                }
217                Expression::Terminal(term) => term.eval_with_context(ctx)?,
218                Expression::Plus(lhs, rhs) => eval(lhs, ctx)? + eval(rhs, ctx)?,
219                Expression::Minus(lhs, rhs) => eval(lhs, ctx)? - eval(rhs, ctx)?,
220                Expression::Times(lhs, rhs) => eval(lhs, ctx)? * eval(rhs, ctx)?,
221                Expression::Divide(lhs, rhs) => eval(lhs, ctx)? / eval(rhs, ctx)?,
222            };
223
224            Ok(ret)
225        }
226
227        // TODO error if top level receives negative value.
228        eval(self, ctx)
229    }
230
231    /// Returns a list of all labels used in the expression.
232    pub fn labels(&self, macros: &MacrosMap) -> Result<Vec<String>, Error> {
233        fn dfs(x: &Expression, m: &MacrosMap) -> Result<Vec<String>, Error> {
234            match x {
235                Expression::Expression(e) => dfs(e, m),
236                Expression::Macro(macro_invocation) => m
237                    .get(&macro_invocation.name)
238                    .context(UnknownMacro {
239                        name: macro_invocation.name.clone(),
240                    })?
241                    .unwrap_expression()
242                    .content
243                    .tree
244                    .labels(m),
245                Expression::Terminal(Terminal::Label(label)) => Ok(vec![label.clone()]),
246                Expression::Terminal(_) => Ok(vec![]),
247                Expression::Plus(lhs, rhs)
248                | Expression::Minus(lhs, rhs)
249                | Expression::Times(lhs, rhs)
250                | Expression::Divide(lhs, rhs) => dfs(lhs, m).and_then(|x: Vec<String>| {
251                    let ret = x.into_iter().chain(dfs(rhs, m)?).collect();
252                    Ok(ret)
253                }),
254            }
255        }
256
257        dfs(self, macros)
258    }
259
260    /// Replaces all instances of `old` with `new` in the expression.
261    pub fn replace_label(&mut self, old: &str, new: &str) {
262        fn dfs(x: &mut Expression, old: &str, new: &str) {
263            match x {
264                Expression::Expression(e) => dfs(e, new, old),
265                Expression::Terminal(Terminal::Label(ref mut label)) => {
266                    if *label == old {
267                        *label = new.to_string();
268                    }
269                }
270                Expression::Plus(lhs, rhs)
271                | Expression::Minus(lhs, rhs)
272                | Expression::Times(lhs, rhs)
273                | Expression::Divide(lhs, rhs) => {
274                    dfs(lhs, new, old);
275                    dfs(rhs, new, old);
276                }
277                Expression::Macro(_) | Expression::Terminal(_) => (),
278            }
279        }
280
281        dfs(self, old, new)
282    }
283
284    /// Replaces all instances of `var` with `expr` in the expression.
285    pub fn fill_variable(&mut self, var: &str, expr: &Expression) {
286        fn dfs(x: &mut Expression, var: &str, expr: &Expression) {
287            match x {
288                Expression::Terminal(Terminal::Variable(name)) => {
289                    if var == name {
290                        *x = expr.clone();
291                    }
292                }
293                Expression::Expression(e) => dfs(e, var, expr),
294                Expression::Plus(lhs, rhs)
295                | Expression::Minus(lhs, rhs)
296                | Expression::Times(lhs, rhs)
297                | Expression::Divide(lhs, rhs) => {
298                    dfs(lhs, var, expr);
299                    dfs(rhs, var, expr);
300                }
301                Expression::Macro(_) | Expression::Terminal(_) => (),
302            }
303        }
304
305        dfs(self, var, expr)
306    }
307}
308
309impl Debug for Terminal {
310    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
311        match self {
312            Terminal::Label(l) => write!(f, r#"Terminal::Label({})"#, l),
313            Terminal::Number(n) => write!(f, r#"Terminal::Number({})"#, n),
314            Terminal::Variable(v) => write!(f, r#"Terminal::Variable({})"#, v),
315        }
316    }
317}
318
319impl fmt::Display for Terminal {
320    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
321        match self {
322            Terminal::Label(l) => write!(f, r#"Label({})"#, l),
323            Terminal::Number(n) => write!(f, r#"{}"#, n),
324            Terminal::Variable(v) => write!(f, r#"Variable({})"#, v),
325        }
326    }
327}
328
329impl From<Terminal> for Expression {
330    fn from(terminal: Terminal) -> Self {
331        Expression::Terminal(terminal)
332    }
333}
334
335impl From<Terminal> for Box<Expression> {
336    fn from(terminal: Terminal) -> Self {
337        Box::new(Expression::Terminal(terminal))
338    }
339}
340
341impl From<u64> for Box<Expression> {
342    fn from(n: u64) -> Self {
343        Box::new(Expression::Terminal(Terminal::Number(n.into())))
344    }
345}
346
347impl From<u64> for Terminal {
348    fn from(n: u64) -> Self {
349        Terminal::Number(n.into())
350    }
351}
352
353impl From<BigInt> for Box<Expression> {
354    fn from(n: BigInt) -> Self {
355        Box::new(n.into())
356    }
357}
358
359impl From<BigInt> for Expression {
360    fn from(n: BigInt) -> Self {
361        Expression::Terminal(Terminal::Number(n))
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    #[test]
    fn expr_simple() {
        // 24 + 42 = 66
        let expr = Expression::Plus(24.into(), 42.into());
        let out = expr.eval().unwrap();
        assert_eq!(out, BigInt::from(66));
    }

    #[test]
    fn expr_nested() {
        // ((1+2)*3)-(4/2) = 7
        let expr = Expression::Minus(
            Expression::Times(Expression::Plus(1.into(), 2.into()).into(), 3.into()).into(),
            Expression::Divide(4.into(), 2.into()).into(),
        );
        let out = expr.eval().unwrap();
        assert_eq!(out, BigInt::from(7));
    }

    #[test]
    fn expr_with_label() {
        // foo + 1 = 42
        let expr = Expression::Plus(Terminal::Label(String::from("foo")).into(), 1.into());
        let labels: HashMap<_, _> = vec![("foo".to_string(), Some(41))].into_iter().collect();
        let out = expr.eval_with_context(Context::from(&labels)).unwrap();
        assert_eq!(out, BigInt::from(42));
    }

    #[test]
    fn expr_unknown_label() {
        // missing label
        let expr = Expression::Plus(Terminal::Label(String::from("foo")).into(), 1.into());
        let err = expr.eval().unwrap_err();
        assert_matches!(err, Error::UnknownLabel { label, .. } if label == "foo");

        // label w/o defined address
        let expr = Expression::Plus(Terminal::Label(String::from("foo")).into(), 1.into());
        let labels: HashMap<_, _> = vec![("foo".to_string(), None)].into_iter().collect();
        let err = expr.eval_with_context(Context::from(&labels)).unwrap_err();
        assert_matches!(err, Error::UnknownLabel { label, .. } if label == "foo");
    }

    #[test]
    fn expr_labels_in_traversal_order() {
        // foo + (bar * 2) references `foo` then `bar`
        let expr = Expression::Plus(
            Terminal::Label(String::from("foo")).into(),
            Expression::Times(Terminal::Label(String::from("bar")).into(), 2.into()).into(),
        );
        let labels = expr.labels(&HashMap::new()).unwrap();
        assert_eq!(labels, vec!["foo".to_string(), "bar".to_string()]);
    }

    #[test]
    fn expr_fill_variable() {
        // x * 3 with x := 5 + 2 evaluates to 21
        let mut expr = Expression::Times(Terminal::Variable(String::from("x")).into(), 3.into());
        let value = Expression::Plus(5.into(), 2.into());
        expr.fill_variable("x", &value);
        assert_eq!(expr.eval().unwrap(), BigInt::from(21));
    }
}