momba_explore/explore/
evaluate.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::convert::TryInto;
4
5use indexmap::IndexSet;
6
7use super::model::*;
8
9/// Represents an evaluation environment.
10#[derive(Debug)]
11pub struct Environment<'r, const BANKS: usize> {
12    pub(crate) banks: [&'r [Value]; BANKS],
13}
14
/// The location of a register: a bank index plus a register index within that bank.
///
/// This is a small pair of indices, so it is `Copy` and comparable/hashable, which lets
/// callers store and compare addresses freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RegisterAddress {
    /// Index of the register bank.
    bank: usize,
    /// Index of the register within the bank.
    register: usize,
}
21
22impl<'r, const BANKS: usize> Environment<'r, BANKS> {
23    /// Creates a new evaluation environment from the given register banks.
24    pub fn new(banks: [&'r [Value]; BANKS]) -> Self {
25        Environment { banks }
26    }
27
28    /// Returns the value stored at the given address.
29    pub fn get_value(&self, address: &RegisterAddress) -> &Value {
30        &self.banks[address.bank][address.register]
31    }
32}
33
34/// Represents a compiled expression.
35pub struct CompiledExpression<const BANKS: usize> {
36    closure: Box<dyn Send + Sync + Fn(&Environment<BANKS>, &mut Vec<Value>) -> Value>,
37    stack_depth: usize,
38}
39
40impl<const BANKS: usize> CompiledExpression<BANKS> {
41    fn new(
42        closure: impl 'static + Send + Sync + Fn(&Environment<BANKS>, &mut Vec<Value>) -> Value,
43        stack_depth: usize,
44    ) -> Self {
45        CompiledExpression {
46            closure: Box::new(closure),
47            stack_depth,
48        }
49    }
50
51    fn evaluate_with_stack(&self, env: &Environment<BANKS>, stack: &mut Vec<Value>) -> Value {
52        (self.closure)(env, stack)
53    }
54
55    pub fn evaluate(&self, env: &Environment<BANKS>) -> Value {
56        self.evaluate_with_stack(env, &mut Vec::with_capacity(self.stack_depth))
57    }
58}
59
60/// Represents an assignment target.
61pub struct Target<'e> {
62    store: &'e mut [Value],
63    index: usize,
64}
65
66impl<'e> Target<'e> {
67    pub fn store(&mut self, value: Value) {
68        self.store[self.index] = value
69    }
70
71    pub fn resolve(self) -> &'e mut Value {
72        &mut self.store[self.index]
73    }
74}
75
76/// Represents a compiled assignment target.
77pub struct CompiledTargetExpression<const BANKS: usize> {
78    closure: Box<
79        dyn Send
80            + Sync
81            + for<'e> Fn(&'e mut [Value], &Environment<BANKS>, &mut Vec<Value>) -> Target<'e>,
82    >,
83    stack_depth: usize,
84}
85
86impl<const BANKS: usize> CompiledTargetExpression<BANKS> {
87    fn new(
88        closure: impl 'static
89            + Send
90            + Sync
91            + for<'e> Fn(&'e mut [Value], &Environment<BANKS>, &mut Vec<Value>) -> Target<'e>,
92        stack_depth: usize,
93    ) -> Self {
94        CompiledTargetExpression {
95            closure: Box::new(closure),
96            stack_depth,
97        }
98    }
99
100    fn evaluate_with_stack<'e>(
101        &self,
102        targets: &'e mut [Value],
103        env: &Environment<BANKS>,
104        stack: &mut Vec<Value>,
105    ) -> Target<'e> {
106        (self.closure)(targets, env, stack)
107    }
108
109    pub fn evaluate<'e>(&self, targets: &'e mut [Value], env: &Environment<BANKS>) -> Target<'e> {
110        self.evaluate_with_stack(targets, env, &mut Vec::with_capacity(self.stack_depth))
111    }
112}
113
114pub trait CompileBackend<C> {
115    fn compile_name(&self, identifier: &str) -> C;
116
117    fn compile_with_context(&self, expression: &Expression, ctx: &mut CompileContext) -> C;
118
119    fn compile(&self, expression: &Expression) -> C;
120}
121
/// Compile-time bookkeeping for variables that live on the evaluation stack.
///
/// Comprehension variables are pushed onto the runtime stack in the same nesting order
/// in which they are registered here, so a variable's position in `stack_variables` is
/// its index into the runtime stack.
#[derive(Clone, Debug, Default)]
pub struct CompileContext {
    /// Largest number of simultaneously bound stack variables seen so far.
    max_stack_depth: usize,
    /// Currently bound stack variables, innermost binding last. Duplicate names are
    /// kept (rather than de-duplicated as a set would) so that an inner binding can
    /// shadow an outer one without losing the outer slot.
    stack_variables: Vec<String>,
}

impl CompileContext {
    /// Creates an empty context with no bound stack variables.
    fn new() -> Self {
        CompileContext {
            max_stack_depth: 0,
            stack_variables: Vec::new(),
        }
    }

    /// Binds `identifier` to the next stack slot and updates the maximum depth.
    ///
    /// A name that is already bound is pushed again, so it shadows the outer binding
    /// and occupies its own slot, matching the runtime stack layout.
    fn push_stack_variable(&mut self, identifier: String) {
        self.stack_variables.push(identifier);
        self.max_stack_depth = self.max_stack_depth.max(self.stack_variables.len());
    }

    /// Removes the innermost stack variable binding.
    fn pop_stack_variable(&mut self) {
        self.stack_variables.pop();
    }

    /// Returns the stack index of the innermost binding of `identifier`, if any.
    fn get_stack_index(&self, identifier: &str) -> Option<usize> {
        self.stack_variables
            .iter()
            .rposition(|variable| variable == identifier)
    }
}
151
/// A compile-time scope: one identifier-to-register table per register bank.
pub struct Scope<const BANKS: usize> {
    banks: [HashMap<String, usize>; BANKS],
}
155
156impl<const BANKS: usize> Scope<BANKS> {
157    pub fn get_address(&self, identifier: &str) -> Option<RegisterAddress> {
158        self.banks
159            .iter()
160            .enumerate()
161            .rev()
162            .filter_map(|(bank, identifiers)| {
163                identifiers.get(identifier).map(|register| RegisterAddress {
164                    bank,
165                    register: *register,
166                })
167            })
168            .next()
169    }
170
171    fn compile_with_context(
172        &self,
173        expression: &Expression,
174        ctx: &mut CompileContext,
175    ) -> CompiledExpression<BANKS> {
176        macro_rules! compile {
177            ($expr:expr) => {
178                self.compile_with_context($expr, ctx)
179            };
180            ($expr:expr; push stack $var:expr) => {{
181                ctx.push_stack_variable($var.into());
182                let compiled = self.compile_with_context($expr, ctx);
183                ctx.pop_stack_variable();
184                compiled
185            }};
186        }
187
188        macro_rules! evaluate {
189            ($expr:expr, $env:expr, $stack:expr) => {
190                $expr.evaluate_with_stack($env, $stack)
191            };
192            ($expr:expr, $env:expr, $stack:expr; push stack $val:expr) => {{
193                $stack.push($val.into());
194                let result = $expr.evaluate_with_stack($env, $stack);
195                $stack.pop();
196                result
197            }};
198        }
199
200        macro_rules! construct {
201            ($closure:expr) => {
202                CompiledExpression::new($closure, ctx.max_stack_depth)
203            };
204        }
205
206        match expression {
207            Expression::Name(NameExpression { identifier }) => {
208                ctx.get_stack_index(identifier).map_or_else(
209                    || {
210                        let address = self.get_address(identifier).unwrap();
211                        construct!(move |env, _| env.get_value(&address).clone())
212                    },
213                    |index| construct!(move |_, stack| stack[index].clone()),
214                )
215            }
216            Expression::Constant(ConstantExpression { value }) => {
217                let value = value.clone();
218                construct!(move |_, _| value.clone())
219            }
220            Expression::Unary(UnaryExpression { operator, operand }) => {
221                let operand = compile!(operand);
222
223                macro_rules! compile_unary {
224                    ($function:ident) => {
225                        construct!(move |env, stack| evaluate!(operand, env, stack).$function())
226                    };
227                }
228
229                match operator {
230                    UnaryOperator::Not => compile_unary!(apply_not),
231                    UnaryOperator::Minus => compile_unary!(apply_minus),
232                    UnaryOperator::Floor => compile_unary!(apply_floor),
233                    UnaryOperator::Ceil => compile_unary!(apply_ceil),
234                    UnaryOperator::Abs => compile_unary!(apply_abs),
235                    UnaryOperator::Sgn => compile_unary!(apply_sgn),
236                    UnaryOperator::Trc => compile_unary!(apply_trc),
237                }
238            }
239            Expression::Binary(BinaryExpression {
240                operator,
241                left,
242                right,
243            }) => {
244                let left = compile!(left);
245                let right = compile!(right);
246
247                macro_rules! compile_binary {
248                    ($function:ident) => {
249                        construct!(move |env, stack| evaluate!(left, env, stack)
250                            .$function(evaluate!(right, env, stack)))
251                    };
252                }
253
254                match operator {
255                    BinaryOperator::Add => compile_binary!(apply_add),
256                    BinaryOperator::Sub => compile_binary!(apply_sub),
257                    BinaryOperator::Mul => compile_binary!(apply_mul),
258                    BinaryOperator::FloorDiv => compile_binary!(apply_floor_div),
259                    BinaryOperator::RealDiv => compile_binary!(apply_real_div),
260                    BinaryOperator::Mod => compile_binary!(apply_mod),
261                    BinaryOperator::Log => compile_binary!(apply_log),
262                    BinaryOperator::Pow => compile_binary!(apply_pow),
263                    BinaryOperator::Min => compile_binary!(apply_min),
264                    BinaryOperator::Max => compile_binary!(apply_max),
265                }
266            }
267            Expression::Comparison(ComparisonExpression {
268                operator,
269                left,
270                right,
271            }) => {
272                let left = compile!(left);
273                let right = compile!(right);
274
275                macro_rules! compile_comparison {
276                    ($function:ident) => {
277                        construct!(move |env, stack| evaluate!(left, env, stack)
278                            .$function(evaluate!(right, env, stack)))
279                    };
280                }
281
282                match operator {
283                    ComparisonOperator::Eq => compile_comparison!(apply_cmp_eq),
284                    ComparisonOperator::Ne => compile_comparison!(apply_cmp_ne),
285                    ComparisonOperator::Lt => compile_comparison!(apply_cmp_lt),
286                    ComparisonOperator::Le => compile_comparison!(apply_cmp_le),
287                    ComparisonOperator::Ge => compile_comparison!(apply_cmp_ge),
288                    ComparisonOperator::Gt => compile_comparison!(apply_cmp_gt),
289                }
290            }
291            Expression::Boolean(BooleanExpression { operator, operands }) => {
292                let operands: Box<[_]> = operands.iter().map(|operand| compile!(operand)).collect();
293
294                macro_rules! compile_boolean {
295                    ($function:ident) => {
296                        construct!(move |env, stack| {
297                            operands
298                                .iter()
299                                .$function(|operand| {
300                                    evaluate!(operand, env, stack).try_into().unwrap()
301                                })
302                                .into()
303                        })
304                    };
305                }
306
307                match operator {
308                    BooleanOperator::And => compile_boolean!(all),
309                    BooleanOperator::Or => compile_boolean!(any),
310                }
311            }
312            Expression::Comprehension(ComprehensionExpression {
313                variable,
314                length,
315                element,
316            }) => {
317                let length = compile!(length);
318                let element = compile!(element; push stack variable);
319
320                construct!(move |env, stack| {
321                    let length = evaluate!(length, env, stack).try_into().unwrap();
322                    Value::Vector(
323                        (0..length)
324                            .map(|index| evaluate!(element, env, stack; push stack index))
325                            .collect(),
326                    )
327                })
328            }
329            Expression::Conditional(ConditionalExpression {
330                condition,
331                consequence,
332                alternative,
333            }) => {
334                let condition = compile!(condition);
335                let consequence = compile!(consequence);
336                let alternative = compile!(alternative);
337
338                construct!(move |env, stack| {
339                    if evaluate!(condition, env, stack).try_into().unwrap() {
340                        evaluate!(consequence, env, stack)
341                    } else {
342                        evaluate!(alternative, env, stack)
343                    }
344                })
345            }
346            Expression::Vector(VectorExpression { elements }) => {
347                let elements: Vec<_> = elements.iter().map(|element| compile!(element)).collect();
348
349                construct!(move |env, stack| {
350                    Value::Vector(
351                        elements
352                            .iter()
353                            .map(|element| evaluate!(element, env, stack))
354                            .collect(),
355                    )
356                })
357            }
358            Expression::Trigonometric(TrigonometricExpression{ function, operand}) => {
359                let operand = compile!(operand);
360
361                macro_rules! compile_trigonometric {
362                    ($function:ident) => {
363                        construct!(move |env, stack| evaluate!(operand, env, stack).$function())
364                    };
365                }
366
367                match function {
368                    TrigonometricFunction::Sin => compile_trigonometric!(apply_sin),
369                    TrigonometricFunction::Cos => compile_trigonometric!(apply_cos),
370                    TrigonometricFunction::Tan => compile_trigonometric!(apply_tan),
371                    _ => panic!("trigonometric function {:?} not implemented", function),
372                }
373            }
374            _ => panic!("not implemented {:?}", expression),
375        }
376    }
377
378    pub fn compile(&self, expression: &Expression) -> CompiledExpression<BANKS> {
379        self.compile_with_context(expression, &mut CompileContext::new())
380    }
381
382    pub fn compile_target(&self, expression: &Expression) -> CompiledTargetExpression<BANKS> {
383        match expression {
384            Expression::Name(NameExpression { identifier }) => {
385                let address = self.get_address(identifier).unwrap();
386                // Is the identifier a global variable?
387                let index = address.register;
388                CompiledTargetExpression::new(
389                    move |targets, _, _| match &mut targets[address.bank] {
390                        Value::Vector(vector) => Target {
391                            store: vector,
392                            index: index,
393                        },
394                        _ => panic!("Expected vector got."),
395                    },
396                    0,
397                )
398            }
399            Expression::Index(IndexExpression { vector, index }) => {
400                let vector = self.compile_target(vector);
401                let index = self.compile(index);
402                let stack_depth = max(vector.stack_depth, index.stack_depth);
403                CompiledTargetExpression::new(
404                    move |targets, env, stack| {
405                        let index = index.evaluate_with_stack(env, stack);
406                        let vector = vector.evaluate_with_stack(targets, env, stack);
407                        match (vector.resolve(), index) {
408                            (Value::Vector(vector), Value::Int64(index)) => Target {
409                                store: vector,
410                                index: index as usize,
411                            },
412                            tuple => {
413                                panic!("Unable to construct assignment target from {:?}.", tuple)
414                            }
415                        }
416                    },
417                    stack_depth,
418                )
419            }
420            _ => panic!("Unable to compile target from expression {:?}.", expression),
421        }
422    }
423}
424
425impl Network {
426    pub fn global_scope(&self) -> Scope<2> {
427        Scope {
428            banks: [
429                self.declarations
430                    .global_variables
431                    .keys()
432                    .enumerate()
433                    .map(|(index, identifier)| (identifier.clone(), index))
434                    .collect(),
435                self.declarations
436                    .transient_variables
437                    .keys()
438                    .enumerate()
439                    .map(|(index, identifier)| (identifier.clone(), index))
440                    .collect(),
441            ],
442        }
443    }
444
445    pub fn transient_scope(&self) -> Scope<1> {
446        Scope {
447            banks: [self
448                .declarations
449                .global_variables
450                .keys()
451                .enumerate()
452                .map(|(index, identifier)| (identifier.clone(), index))
453                .collect()],
454        }
455    }
456}
457
458impl Edge {
459    pub fn edge_scope(&self, network: &Network, edge: &Edge) -> Scope<3> {
460        let global_scope = network.global_scope();
461        Scope {
462            banks: [
463                global_scope.banks[0].clone(),
464                global_scope.banks[1].clone(),
465                match &edge.pattern {
466                    ActionPattern::Silent => HashMap::new(),
467                    ActionPattern::Labeled(labeled) => labeled
468                        .arguments
469                        .iter()
470                        .enumerate()
471                        .filter_map(|(index, argument)| match argument {
472                            PatternArgument::Read { identifier } => {
473                                Some((identifier.clone(), index))
474                            }
475                            _ => None,
476                        })
477                        .collect(),
478                },
479            ],
480        }
481    }
482}