1use super::macros::{ExpressionMacroInvocation, MacroDefinition};
2use num_bigint::BigInt;
3use snafu::OptionExt;
4use snafu::{Backtrace, Snafu};
5use std::collections::HashMap;
6use std::fmt::{self, Debug};
7
/// Errors reported while evaluating or inspecting an [`Expression`].
///
/// The snafu context selectors are generated without a suffix, so they carry
/// the same names as the variants (e.g. `UnknownLabel { label }`).
#[derive(Snafu, Debug)]
#[snafu(context(suffix(false)), visibility(pub))]
pub enum Error {
    /// A label was referenced that is either missing from the context's label
    /// table or present there with no value (`None`).
    #[snafu(display("unknown label `{}`", label))]
    #[non_exhaustive]
    UnknownLabel { label: String, backtrace: Backtrace },

    /// A macro was invoked whose name is not present in the macro table.
    #[snafu(display("unknown macro `{}`", name))]
    #[non_exhaustive]
    UnknownMacro { name: String, backtrace: Backtrace },

    /// A variable was referenced that is not bound in the context's variable
    /// table.
    #[snafu(display("undefined macro variable `{}`", name))]
    #[non_exhaustive]
    UndefinedVariable { name: String, backtrace: Backtrace },
}
24
/// Label name -> value. An entry holding `None` is treated like a missing
/// label during evaluation (presumably a label declared but not yet
/// resolved — confirm against the code that fills this map).
type LabelsMap = HashMap<String, Option<usize>>;
/// Variable name -> the expression currently bound to it.
type VariablesMap = HashMap<String, Expression>;
/// Macro name -> its definition.
type MacrosMap = HashMap<String, MacroDefinition>;
28
/// Borrowed lookup tables consulted while evaluating an [`Expression`].
///
/// Every table is optional; a missing table behaves like an empty one in the
/// `get_*` accessors. The default context has no tables at all.
#[derive(Clone, Copy, Debug, Default)]
pub struct Context<'a> {
    /// Table used to resolve `Terminal::Label`.
    labels: Option<&'a LabelsMap>,
    /// Table used to resolve `Expression::Macro` invocations.
    macros: Option<&'a MacrosMap>,
    /// Table used to resolve `Terminal::Variable`.
    variables: Option<&'a VariablesMap>,
}
36
37impl<'a> Context<'a> {
38 pub fn get_label(&self, key: &str) -> Option<&Option<usize>> {
40 match self.labels {
41 Some(labels) => labels.get(key),
42 None => None,
43 }
44 }
45
46 pub fn get_macro(&self, key: &str) -> Option<&MacroDefinition> {
48 match self.macros {
49 Some(macros) => macros.get(key),
50 None => None,
51 }
52 }
53
54 pub fn get_variable(&self, key: &str) -> Option<&Expression> {
56 match self.variables {
57 Some(variables) => variables.get(key),
58 None => None,
59 }
60 }
61}
62
63impl<'a> From<&'a LabelsMap> for Context<'a> {
64 fn from(labels: &'a LabelsMap) -> Self {
65 Self {
66 labels: Some(labels),
67 macros: None,
68 variables: None,
69 }
70 }
71}
72
73impl<'a> From<(&'a LabelsMap, &'a MacrosMap)> for Context<'a> {
74 fn from(x: (&'a LabelsMap, &'a MacrosMap)) -> Self {
75 Self {
76 labels: Some(x.0),
77 macros: Some(x.1),
78 variables: None,
79 }
80 }
81}
82
83impl<'a> From<(&'a LabelsMap, &'a MacrosMap, &'a VariablesMap)> for Context<'a> {
84 fn from(x: (&'a LabelsMap, &'a MacrosMap, &'a VariablesMap)) -> Self {
85 Self {
86 labels: Some(x.0),
87 macros: Some(x.1),
88 variables: Some(x.2),
89 }
90 }
91}
92
/// Arithmetic expression tree.
///
/// Binary variants hold their left operand first and their right operand
/// second.
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Expression {
    /// A wrapped sub-expression; rendered inside parentheses by `Display`
    /// and `Debug`.
    Expression(Box<Self>),

    /// Invocation of a macro, resolved through the macro table at
    /// evaluation time.
    Macro(ExpressionMacroInvocation),

    /// A leaf value (number, label or variable).
    Terminal(Terminal),

    /// `lhs + rhs`
    Plus(Box<Self>, Box<Self>),

    /// `lhs - rhs`
    Minus(Box<Self>, Box<Self>),

    /// `lhs * rhs`
    Times(Box<Self>, Box<Self>),

    /// `lhs / rhs`
    Divide(Box<Self>, Box<Self>),
}
117
118impl Debug for Expression {
119 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120 match self {
121 Expression::Expression(s) => write!(f, r#"({:?})"#, s),
122 Expression::Macro(m) => write!(f, r#"Expression::Macro("{}")"#, m.name),
123 Expression::Terminal(t) => write!(f, r#"Expression::Terminal({:?})"#, t),
124 Expression::Plus(lhs, rhs) => write!(f, r#"Expression::Plus({:?}, {:?})"#, lhs, rhs),
125 Expression::Minus(lhs, rhs) => write!(f, r#"Expression::Minus({:?}, {:?})"#, lhs, rhs),
126 Expression::Times(lhs, rhs) => write!(f, r#"Expression::Times({:?}, {:?})"#, lhs, rhs),
127 Expression::Divide(lhs, rhs) => {
128 write!(f, r#"Expression::Divide({:?}, {:?})"#, lhs, rhs)
129 }
130 }
131 }
132}
133
134impl fmt::Display for Expression {
135 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
136 match self {
137 Expression::Expression(s) => write!(f, r#"({})"#, s),
138 Expression::Macro(m) => write!(f, r#"{}"#, m),
139 Expression::Terminal(t) => write!(f, r#"{}"#, t),
140 Expression::Plus(lhs, rhs) => write!(f, r#"{}+{}"#, lhs, rhs),
141 Expression::Minus(lhs, rhs) => write!(f, r#"{}-{}"#, lhs, rhs),
142 Expression::Times(lhs, rhs) => write!(f, r#"{}*{}"#, lhs, rhs),
143 Expression::Divide(lhs, rhs) => write!(f, r#"{}/{}"#, lhs, rhs),
144 }
145 }
146}
147
/// Leaf of an [`Expression`] tree.
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Terminal {
    /// A literal integer.
    Number(BigInt),

    /// A reference to a label, resolved through the context's label table.
    Label(String),

    /// A reference to a variable, resolved through the context's variable
    /// table.
    Variable(String),
}
160
161impl Terminal {
162 pub fn eval(&self) -> Result<BigInt, Error> {
164 self.eval_with_context(Context::default())
165 }
166
167 pub fn eval_with_context(&self, ctx: Context) -> Result<BigInt, Error> {
169 let ret = match self {
170 Terminal::Number(n) => n.clone(),
171 Terminal::Label(label) => ctx
172 .get_label(label)
173 .context(UnknownLabel { label })?
174 .context(UnknownLabel { label })?
175 .into(),
176 Terminal::Variable(name) => ctx
177 .get_variable(name)
178 .context(UndefinedVariable { name })?
179 .eval_with_context(ctx)?,
180 };
181
182 Ok(ret)
183 }
184}
185
186impl Expression {
187 pub fn eval(&self) -> Result<BigInt, Error> {
189 self.eval_with_context(Context::default())
190 }
191
192 pub fn eval_with_context(&self, ctx: Context) -> Result<BigInt, Error> {
194 fn eval(e: &Expression, ctx: Context) -> Result<BigInt, Error> {
195 let ret = match e {
196 Expression::Expression(expr) => eval(expr, ctx)?,
197 Expression::Macro(invc) => {
198 let defn = ctx.get_macro(&invc.name).context(UnknownMacro {
199 name: invc.name.clone(),
200 })?;
201
202 let vars = defn
203 .parameters()
204 .iter()
205 .cloned()
206 .zip(invc.parameters.iter().cloned())
207 .collect();
208
209 let mut ctx = ctx;
210 ctx.variables = Some(&vars);
211
212 defn.unwrap_expression()
213 .content
214 .tree
215 .eval_with_context(ctx)?
216 }
217 Expression::Terminal(term) => term.eval_with_context(ctx)?,
218 Expression::Plus(lhs, rhs) => eval(lhs, ctx)? + eval(rhs, ctx)?,
219 Expression::Minus(lhs, rhs) => eval(lhs, ctx)? - eval(rhs, ctx)?,
220 Expression::Times(lhs, rhs) => eval(lhs, ctx)? * eval(rhs, ctx)?,
221 Expression::Divide(lhs, rhs) => eval(lhs, ctx)? / eval(rhs, ctx)?,
222 };
223
224 Ok(ret)
225 }
226
227 eval(self, ctx)
229 }
230
231 pub fn labels(&self, macros: &MacrosMap) -> Result<Vec<String>, Error> {
233 fn dfs(x: &Expression, m: &MacrosMap) -> Result<Vec<String>, Error> {
234 match x {
235 Expression::Expression(e) => dfs(e, m),
236 Expression::Macro(macro_invocation) => m
237 .get(¯o_invocation.name)
238 .context(UnknownMacro {
239 name: macro_invocation.name.clone(),
240 })?
241 .unwrap_expression()
242 .content
243 .tree
244 .labels(m),
245 Expression::Terminal(Terminal::Label(label)) => Ok(vec![label.clone()]),
246 Expression::Terminal(_) => Ok(vec![]),
247 Expression::Plus(lhs, rhs)
248 | Expression::Minus(lhs, rhs)
249 | Expression::Times(lhs, rhs)
250 | Expression::Divide(lhs, rhs) => dfs(lhs, m).and_then(|x: Vec<String>| {
251 let ret = x.into_iter().chain(dfs(rhs, m)?).collect();
252 Ok(ret)
253 }),
254 }
255 }
256
257 dfs(self, macros)
258 }
259
260 pub fn replace_label(&mut self, old: &str, new: &str) {
262 fn dfs(x: &mut Expression, old: &str, new: &str) {
263 match x {
264 Expression::Expression(e) => dfs(e, new, old),
265 Expression::Terminal(Terminal::Label(ref mut label)) => {
266 if *label == old {
267 *label = new.to_string();
268 }
269 }
270 Expression::Plus(lhs, rhs)
271 | Expression::Minus(lhs, rhs)
272 | Expression::Times(lhs, rhs)
273 | Expression::Divide(lhs, rhs) => {
274 dfs(lhs, new, old);
275 dfs(rhs, new, old);
276 }
277 Expression::Macro(_) | Expression::Terminal(_) => (),
278 }
279 }
280
281 dfs(self, old, new)
282 }
283
284 pub fn fill_variable(&mut self, var: &str, expr: &Expression) {
286 fn dfs(x: &mut Expression, var: &str, expr: &Expression) {
287 match x {
288 Expression::Terminal(Terminal::Variable(name)) => {
289 if var == name {
290 *x = expr.clone();
291 }
292 }
293 Expression::Expression(e) => dfs(e, var, expr),
294 Expression::Plus(lhs, rhs)
295 | Expression::Minus(lhs, rhs)
296 | Expression::Times(lhs, rhs)
297 | Expression::Divide(lhs, rhs) => {
298 dfs(lhs, var, expr);
299 dfs(rhs, var, expr);
300 }
301 Expression::Macro(_) | Expression::Terminal(_) => (),
302 }
303 }
304
305 dfs(self, var, expr)
306 }
307}
308
309impl Debug for Terminal {
310 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
311 match self {
312 Terminal::Label(l) => write!(f, r#"Terminal::Label({})"#, l),
313 Terminal::Number(n) => write!(f, r#"Terminal::Number({})"#, n),
314 Terminal::Variable(v) => write!(f, r#"Terminal::Variable({})"#, v),
315 }
316 }
317}
318
319impl fmt::Display for Terminal {
320 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
321 match self {
322 Terminal::Label(l) => write!(f, r#"Label({})"#, l),
323 Terminal::Number(n) => write!(f, r#"{}"#, n),
324 Terminal::Variable(v) => write!(f, r#"Variable({})"#, v),
325 }
326 }
327}
328
329impl From<Terminal> for Expression {
330 fn from(terminal: Terminal) -> Self {
331 Expression::Terminal(terminal)
332 }
333}
334
335impl From<Terminal> for Box<Expression> {
336 fn from(terminal: Terminal) -> Self {
337 Box::new(Expression::Terminal(terminal))
338 }
339}
340
341impl From<u64> for Box<Expression> {
342 fn from(n: u64) -> Self {
343 Box::new(Expression::Terminal(Terminal::Number(n.into())))
344 }
345}
346
347impl From<u64> for Terminal {
348 fn from(n: u64) -> Self {
349 Terminal::Number(n.into())
350 }
351}
352
353impl From<BigInt> for Box<Expression> {
354 fn from(n: BigInt) -> Self {
355 Box::new(n.into())
356 }
357}
358
359impl From<BigInt> for Expression {
360 fn from(n: BigInt) -> Self {
361 Expression::Terminal(Terminal::Number(n))
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;

    /// Builds the expression `foo + 1`, where `foo` is a label reference.
    fn foo_plus_one() -> Expression {
        Expression::Plus(Terminal::Label(String::from("foo")).into(), 1.into())
    }

    /// Builds a label table holding the single entry `foo -> value`.
    fn foo_table(value: Option<usize>) -> LabelsMap {
        let mut table = LabelsMap::new();
        table.insert("foo".to_string(), value);
        table
    }

    // 24 + 42
    #[test]
    fn expr_simple() {
        let sum = Expression::Plus(24.into(), 42.into());
        assert_eq!(sum.eval().unwrap(), BigInt::from(66));
    }

    // (1 + 2) * 3 - 4 / 2
    #[test]
    fn expr_nested() {
        let one_plus_two = Expression::Plus(1.into(), 2.into());
        let product = Expression::Times(one_plus_two.into(), 3.into());
        let quotient = Expression::Divide(4.into(), 2.into());
        let difference = Expression::Minus(product.into(), quotient.into());

        assert_eq!(difference.eval().unwrap(), BigInt::from(7));
    }

    // foo + 1 with foo = 41
    #[test]
    fn expr_with_label() {
        let table = foo_table(Some(41));
        let value = foo_plus_one()
            .eval_with_context(Context::from(&table))
            .unwrap();

        assert_eq!(value, BigInt::from(42));
    }

    #[test]
    fn expr_unknown_label() {
        // No label table at all.
        let missing = foo_plus_one().eval().unwrap_err();
        assert_matches!(missing, Error::UnknownLabel { label, .. } if label == "foo");

        // The label is present in the table but has no value.
        let table = foo_table(None);
        let unresolved = foo_plus_one()
            .eval_with_context(Context::from(&table))
            .unwrap_err();
        assert_matches!(unresolved, Error::UnknownLabel { label, .. } if label == "foo");
    }
}