gluon_vm/
compiler.rs

1use std::ops::{Deref, DerefMut};
2
3use crate::base::{
4    ast::{DisplayEnv, Typed, TypedIdent},
5    kind::{ArcKind, KindEnv},
6    pos::Line,
7    resolve,
8    scoped_map::ScopedMap,
9    source::{FileMap, Source},
10    symbol::{Symbol, SymbolData, SymbolModule, SymbolRef},
11    types::{Alias, ArcType, BuiltinType, NullInterner, Type, TypeEnv, TypeExt},
12};
13
14use crate::{
15    core::{self, is_primitive, CExpr, Expr, Literal, Pattern},
16    interner::InternedStr,
17    source_map::{LocalMap, SourceMap},
18    types::*,
19    vm::GlobalVmState,
20    Error, Result,
21};
22
23use self::Variable::*;
24
/// Describes where a resolved identifier can be found.
///
/// `G` is the payload carried by `UpVar`. In this file it is instantiated as `Symbol`
/// (by `CompilerEnv::find_var`) and as `VmIndex` (by `Compiler::find`).
#[derive(Clone, Debug)]
pub enum Variable<G> {
    /// A variable bound in the stack of the function being compiled, identified by its slot index.
    Stack(VmIndex),
    /// A variant constructor: its tag and the number of arguments it takes.
    Constructor(VmTag, VmIndex),
    /// A variable which is not in the current function's stack and is accessed as an upvar.
    UpVar(G),
}
31
/// Field accesses on records can either be by name in the case of polymorphic records or by offset
/// when the record is non-polymorphic (which is faster)
enum FieldAccess {
    /// The field must be looked up by its name at runtime.
    Name,
    /// The field lives at a fixed, statically known position.
    Index(VmIndex),
}
38
/// Debug information describing a single upvar of a compiled function.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone)]
#[cfg_attr(feature = "serde_derive", derive(DeserializeState, SerializeState))]
#[cfg_attr(
    feature = "serde_derive",
    serde(
        deserialize_state = "crate::serialization::DeSeed<'gc>",
        de_parameters = "'gc"
    )
)]
#[cfg_attr(
    feature = "serde_derive",
    serde(serialize_state = "crate::serialization::SeSeed")
)]
pub struct UpvarInfo {
    /// The declared name of the captured variable.
    pub name: String,
    /// The type of the captured variable.
    #[cfg_attr(
        feature = "serde_derive",
        serde(state_with = "crate::serialization::borrow")
    )]
    pub typ: ArcType,
}
60
/// Debug information attached to every `CompiledFunction`.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone)]
#[cfg_attr(feature = "serde_derive", derive(DeserializeState, SerializeState))]
#[cfg_attr(
    feature = "serde_derive",
    serde(
        deserialize_state = "crate::serialization::DeSeed<'gc>",
        de_parameters = "'gc"
    )
)]
#[cfg_attr(
    feature = "serde_derive",
    serde(serialize_state = "crate::serialization::SeSeed")
)]
pub struct DebugInfo {
    /// Maps instruction indexes to the line that spawned them
    pub source_map: SourceMap,
    /// Tracks named stack variables. Entries are opened by `FunctionEnv::new_stack_var` and
    /// closed by `FunctionEnv::exit_scope`, both keyed on the current instruction index.
    #[cfg_attr(feature = "serde_derive", serde(state))]
    pub local_map: LocalMap,
    /// Name and type of each free variable of the function. Filled in by
    /// `FunctionEnvs::end_function` for nested functions when debug info is enabled.
    #[cfg_attr(feature = "serde_derive", serde(state))]
    pub upvars: Vec<UpvarInfo>,
    /// Name of the source the function was compiled from.
    pub source_name: String,
}
83
/// The result of compiling an expression: the top-level function together with the globals it
/// refers to.
// NOTE(review): the derive below is gated on the `serde_derive` feature while the container-level
// `serde(...)` attributes are gated on `serde_derive_state` — confirm both features are always
// enabled together.
#[derive(Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde_derive", derive(DeserializeState, SerializeState))]
#[cfg_attr(
    feature = "serde_derive_state",
    serde(
        deserialize_state = "crate::serialization::DeSeed<'gc>",
        de_parameters = "'gc"
    )
)]
#[cfg_attr(
    feature = "serde_derive_state",
    serde(serialize_state = "crate::serialization::SeSeed")
)]
pub struct CompiledModule {
    /// Storage for globals which are needed by the module which is currently being compiled
    #[cfg_attr(
        feature = "serde_derive",
        serde(state_with = "crate::serialization::borrow")
    )]
    pub module_globals: Vec<Symbol>,
    /// The top-level function of the module.
    #[cfg_attr(
        feature = "serde_derive",
        serde(state_with = "crate::serialization::borrow")
    )]
    pub function: CompiledFunction,
}
110
/// A single compiled function: its instructions plus the constant tables and nested functions
/// those instructions refer to by index.
#[derive(Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde_derive", derive(DeserializeState, SerializeState))]
#[cfg_attr(
    feature = "serde_derive_state",
    serde(
        deserialize_state = "crate::serialization::DeSeed<'gc>",
        de_parameters = "'gc"
    )
)]
#[cfg_attr(
    feature = "serde_derive_state",
    serde(serialize_state = "crate::serialization::SeSeed")
)]
pub struct CompiledFunction {
    /// The number of arguments the function takes
    pub args: VmIndex,
    /// The maximum possible number of stack slots needed for this function
    pub max_stack_size: VmIndex,

    /// The symbol naming this function
    #[cfg_attr(
        feature = "serde_derive",
        serde(state_with = "crate::serialization::borrow")
    )]
    pub id: Symbol,

    /// The type of this function
    #[cfg_attr(
        feature = "serde_derive",
        serde(state_with = "crate::serialization::borrow")
    )]
    pub typ: ArcType,
    /// The instructions emitted for the function body
    pub instructions: Vec<Instruction>,

    /// Functions compiled for the closures defined inside this function
    #[cfg_attr(feature = "serde_derive_state", serde(state))]
    pub inner_functions: Vec<CompiledFunction>,

    /// String constants, referenced by index from the instructions
    #[cfg_attr(feature = "serde_derive_state", serde(state))]
    pub strings: Vec<InternedStr>,

    /// Lists of record field names, referenced by index from the instructions
    #[cfg_attr(
        feature = "serde_derive",
        serde(state_with = "crate::serialization::borrow")
    )]
    pub records: Vec<Vec<Symbol>>,

    /// Debug information collected while compiling the function
    #[cfg_attr(feature = "serde_derive_state", serde(state))]
    pub debug_info: DebugInfo,
}
157
158impl From<CompiledFunction> for CompiledModule {
159    fn from(function: CompiledFunction) -> Self {
160        CompiledModule {
161            module_globals: Vec::new(),
162            function,
163        }
164    }
165}
166
167impl CompiledFunction {
168    pub fn new(args: VmIndex, id: Symbol, typ: ArcType, source_name: String) -> CompiledFunction {
169        CompiledFunction {
170            args: args,
171            max_stack_size: 0,
172            id: id,
173            typ: typ,
174            instructions: Vec::new(),
175            inner_functions: Vec::new(),
176            strings: Vec::new(),
177            records: Vec::new(),
178            debug_info: DebugInfo {
179                source_map: SourceMap::new(),
180                local_map: LocalMap::new(),
181                upvars: Vec::new(),
182                source_name: source_name,
183            },
184        }
185    }
186}
187
/// The state kept for a single function while it is being compiled.
struct FunctionEnv {
    /// The variables currently in scope in this function.
    stack: ScopedMap<Symbol, (VmIndex, ArcType)>,
    /// The current size of the stack. Not the same as `stack.len()`.
    stack_size: VmIndex,
    /// The variables which this function takes from the outer scope
    free_vars: Vec<(Symbol, ArcType)>,
    /// The line where instructions are currently being emitted
    current_line: Line,
    /// Whether the source map and local map are updated while emitting
    emit_debug_info: bool,
    /// The function which is being built
    function: CompiledFunction,
}
201
/// The stack of functions which are currently being compiled.
///
/// The last element is the innermost function; it is what the `Deref`/`DerefMut`
/// implementations of this type resolve to.
struct FunctionEnvs {
    envs: Vec<FunctionEnv>,
}
205
206impl Deref for FunctionEnvs {
207    type Target = FunctionEnv;
208    fn deref(&self) -> &FunctionEnv {
209        self.envs.last().expect("FunctionEnv")
210    }
211}
212
213impl DerefMut for FunctionEnvs {
214    fn deref_mut(&mut self) -> &mut FunctionEnv {
215        self.envs.last_mut().expect("FunctionEnv")
216    }
217}
218
219impl FunctionEnvs {
220    fn new() -> FunctionEnvs {
221        FunctionEnvs { envs: vec![] }
222    }
223
224    fn start_function(&mut self, compiler: &mut Compiler, args: VmIndex, id: Symbol, typ: ArcType) {
225        compiler.stack_types.enter_scope();
226        self.envs.push(FunctionEnv::new(
227            args,
228            id,
229            typ,
230            compiler.source_name.clone(),
231            compiler.emit_debug_info,
232        ));
233    }
234
235    fn end_function(&mut self, compiler: &mut Compiler, current_line: Option<Line>) -> FunctionEnv {
236        compiler.stack_types.exit_scope();
237        self.function.instructions.push(Instruction::Return);
238        let instructions = self.function.instructions.len();
239
240        if compiler.emit_debug_info {
241            self.function
242                .debug_info
243                .source_map
244                .close(instructions, current_line);
245
246            let upvars_are_globals = self.envs.len() == 1;
247            if !upvars_are_globals {
248                let function = &mut **self;
249                function
250                    .function
251                    .debug_info
252                    .upvars
253                    .extend(
254                        function
255                            .free_vars
256                            .iter()
257                            .map(|&(ref name, ref typ)| UpvarInfo {
258                                name: name.declared_name().to_string(),
259                                typ: typ.clone(),
260                            }),
261                    );
262            }
263        }
264
265        self.envs.pop().expect("FunctionEnv in scope")
266    }
267}
268
269impl FunctionEnv {
270    fn new(
271        args: VmIndex,
272        id: Symbol,
273        typ: ArcType,
274        source_name: String,
275        emit_debug_info: bool,
276    ) -> FunctionEnv {
277        FunctionEnv {
278            free_vars: Vec::new(),
279            stack: ScopedMap::new(),
280            stack_size: 0,
281            function: CompiledFunction::new(args, id, typ, source_name),
282            current_line: Line::from(0),
283            emit_debug_info,
284        }
285    }
286
287    fn emit(&mut self, instruction: Instruction) {
288        if let Slide(0) = instruction {
289            return;
290        }
291
292        let adjustment = instruction.adjust();
293        debug!("{:?} {} {}", instruction, self.stack_size, adjustment);
294        if adjustment > 0 {
295            self.increase_stack(adjustment as VmIndex);
296        } else {
297            self.stack_size -= -adjustment as VmIndex;
298        }
299
300        self.function.instructions.push(instruction);
301
302        if self.emit_debug_info {
303            self.function
304                .debug_info
305                .source_map
306                .emit(self.function.instructions.len() - 1, self.current_line);
307        }
308    }
309
310    fn increase_stack(&mut self, adjustment: VmIndex) {
311        use std::cmp::max;
312
313        self.stack_size += adjustment;
314        self.function.max_stack_size = max(self.function.max_stack_size, self.stack_size);
315    }
316
317    fn emit_call(&mut self, args: VmIndex, tail_position: bool) {
318        let i = if tail_position {
319            TailCall(args)
320        } else {
321            Call(args)
322        };
323        self.emit(i);
324    }
325
326    fn emit_field(&mut self, compiler: &mut Compiler, typ: &ArcType, field: &Symbol) -> Result<()> {
327        let field_index = compiler
328            .find_field(typ, field)
329            .expect("ICE: Undefined field in field access");
330        match field_index {
331            FieldAccess::Index(i) => self.emit(GetOffset(i)),
332            FieldAccess::Name => {
333                let interned = compiler.intern(field.as_ref())?;
334                let index = self.add_string_constant(interned);
335                self.emit(GetField(index));
336            }
337        }
338        Ok(())
339    }
340
341    fn add_record_map(&mut self, fields: Vec<Symbol>) -> VmIndex {
342        match self.function.records.iter().position(|t| *t == fields) {
343            Some(i) => i as VmIndex,
344            None => {
345                self.function.records.push(fields);
346                (self.function.records.len() - 1) as VmIndex
347            }
348        }
349    }
350
351    fn add_string_constant(&mut self, s: InternedStr) -> VmIndex {
352        match self.function.strings.iter().position(|t| *t == s) {
353            Some(i) => i as VmIndex,
354            None => {
355                self.function.strings.push(s);
356                (self.function.strings.len() - 1) as VmIndex
357            }
358        }
359    }
360
361    fn emit_string(&mut self, s: InternedStr) {
362        let index = self.add_string_constant(s);
363        self.emit(PushString(index as VmIndex));
364    }
365
366    fn upvar(&mut self, s: &Symbol, typ: &ArcType) -> VmIndex {
367        match self.free_vars.iter().position(|t| t.0 == *s) {
368            Some(index) => index as VmIndex,
369            None => {
370                self.free_vars.push((s.clone(), typ.clone()));
371                (self.free_vars.len() - 1) as VmIndex
372            }
373        }
374    }
375
376    fn stack_size(&mut self) -> VmIndex {
377        self.stack_size as VmIndex
378    }
379
380    fn push_stack_var(&mut self, compiler: &Compiler, s: Symbol, typ: ArcType) {
381        self.increase_stack(1);
382        self.new_stack_var(compiler, s, typ)
383    }
384
385    fn new_stack_var(&mut self, compiler: &Compiler, s: Symbol, typ: ArcType) {
386        debug!("Push var: {:?} at {}", s, self.stack_size - 1);
387        let index = self.stack_size - 1;
388        if self.emit_debug_info && compiler.empty_symbol != s {
389            self.function.debug_info.local_map.emit(
390                self.function.instructions.len(),
391                index,
392                s.clone(),
393                typ.clone(),
394            );
395        }
396        self.stack.insert(s, (index, typ));
397    }
398
399    fn exit_scope(&mut self, compiler: &Compiler) -> VmIndex {
400        let mut count = 0;
401        for x in self.stack.exit_scope() {
402            count += 1;
403            debug!("Pop var: ({:?}, {:?})", x.0, (x.1).0);
404            if self.emit_debug_info && compiler.empty_symbol != x.0 {
405                self.function
406                    .debug_info
407                    .local_map
408                    .close(self.function.instructions.len());
409            }
410        }
411        count
412    }
413}
414
/// The environment the compiler consults for identifiers which are not bound in any of the
/// functions being compiled.
pub trait CompilerEnv: TypeEnv {
    /// Looks up `id`, returning where it can be found together with its type, or `None` if the
    /// environment does not know the identifier.
    fn find_var(&self, id: &Symbol) -> Option<(Variable<Symbol>, ArcType)>;
}
418
419impl CompilerEnv for TypeInfos {
420    fn find_var(&self, id: &Symbol) -> Option<(Variable<Symbol>, ArcType)> {
421        fn count_function_args(typ: &ArcType) -> VmIndex {
422            match typ.as_function() {
423                Some((_, ret)) => 1 + count_function_args(ret),
424                None => 0,
425            }
426        }
427
428        self.id_to_type
429            .iter()
430            .filter_map(|(_, ref alias)| match **alias.unresolved_type() {
431                Type::Variant(ref row) => row
432                    .row_iter()
433                    .enumerate()
434                    .find(|&(_, field)| field.name == *id),
435                _ => None,
436            })
437            .next()
438            .map(|(tag, field)| {
439                (
440                    Variable::Constructor(tag as VmTag, count_function_args(&field.typ)),
441                    field.typ.clone(),
442                )
443            })
444    }
445}
446
/// Compiles core expressions into `CompiledModule`s.
pub struct Compiler<'a> {
    /// Consulted for identifiers which are not bound in any function being compiled
    globals: &'a (dyn CompilerEnv<Type = ArcType> + 'a),
    /// Used to intern strings
    vm: &'a GlobalVmState,
    /// Used to create and display symbols
    symbols: SymbolModule<'a>,
    /// The type aliases currently in scope; a new scope is entered for each function
    stack_types: ScopedMap<Symbol, Alias<Symbol, ArcType>>,
    /// Used to translate byte positions into line numbers
    source: &'a FileMap,
    /// Name of the source being compiled, copied into each function's debug info
    source_name: String,
    /// Whether debug information is recorded for the compiled functions
    emit_debug_info: bool,
    /// The symbol for the empty name. It names the top-level function and variables bound to
    /// it are left out of the local map.
    empty_symbol: Symbol,
    /// A `Type::hole()` created once in `new`.
    // NOTE(review): no use of this field is visible in this part of the file — confirm it is
    // still needed.
    hole: ArcType,
}
458
/// The compiler does not track kinds, every lookup fails.
impl<'a> KindEnv for Compiler<'a> {
    /// Always returns `None`.
    fn find_kind(&self, _type_name: &SymbolRef) -> Option<ArcKind> {
        None
    }
}
464
/// A type environment which only knows the type aliases stored in `stack_types`.
impl<'a> TypeEnv for Compiler<'a> {
    type Type = ArcType;

    /// Always returns `None`; the compiler itself does not provide types for identifiers.
    fn find_type(&self, _id: &SymbolRef) -> Option<ArcType> {
        None
    }

    /// Returns a clone of the alias registered for `id` in `stack_types`, if there is one.
    fn find_type_info(&self, id: &SymbolRef) -> Option<Alias<Symbol, ArcType>> {
        self.stack_types.get(id).cloned()
    }
}
476
477impl<'a, T: CompilerEnv> CompilerEnv for &'a T {
478    fn find_var(&self, s: &Symbol) -> Option<(Variable<Symbol>, ArcType)> {
479        (**self).find_var(s)
480    }
481}
482
483impl<'a> Compiler<'a> {
484    pub fn new(
485        globals: &'a (dyn CompilerEnv<Type = ArcType> + 'a),
486        vm: &'a GlobalVmState,
487        mut symbols: SymbolModule<'a>,
488        source: &'a FileMap,
489        source_name: String,
490        emit_debug_info: bool,
491    ) -> Compiler<'a> {
492        Compiler {
493            globals: globals,
494            vm: vm,
495            empty_symbol: symbols.simple_symbol(""),
496            symbols: symbols,
497            stack_types: ScopedMap::new(),
498            source: source,
499            source_name: source_name,
500            emit_debug_info,
501            hole: Type::hole(),
502        }
503    }
504
    /// Interns `s` through the VM, passing on any error the VM reports.
    fn intern(&mut self, s: &str) -> Result<InternedStr> {
        self.vm.intern(s)
    }
508
509    fn find(&self, id: &Symbol, current: &mut FunctionEnvs) -> Option<Variable<VmIndex>> {
510        current
511            .stack
512            .get(id)
513            .map(|&(index, _)| Stack(index))
514            .or_else(|| {
515                let i = current.envs.len() - 1;
516                let (rest, current) = current.envs.split_at_mut(i);
517                rest.iter()
518                    .rev()
519                    .filter_map(|env| {
520                        env.stack
521                            .get(id)
522                            .map(|&(_, ref typ)| UpVar(current[0].upvar(id, typ)))
523                    })
524                    .next()
525            })
526            .or_else(|| {
527                self.globals
528                    .find_var(&id)
529                    .map(|(variable, typ)| match variable {
530                        Stack(i) => Stack(i),
531                        UpVar(id) => UpVar(current.upvar(&id, &typ)),
532                        Constructor(tag, args) => Constructor(tag, args),
533                    })
534            })
535    }
536
537    fn find_field(&self, typ: &ArcType, field: &Symbol) -> Option<FieldAccess> {
538        // Remove all type aliases to get the actual record type
539        let typ = resolve::remove_aliases_cow(self, &mut NullInterner, typ);
540        let mut iter = typ.remove_forall().row_iter();
541        match iter.by_ref().position(|f| f.name.name_eq(field)) {
542            Some(index) => {
543                for _ in iter.by_ref() {}
544                Some(if **iter.current_type() == Type::EmptyRow {
545                    // Non-polymorphic record, access by index
546                    FieldAccess::Index(index as VmIndex)
547                } else {
548                    FieldAccess::Name
549                })
550            }
551            None => None,
552        }
553    }
554
555    fn find_tag(&self, typ: &ArcType, constructor: &Symbol) -> Option<FieldAccess> {
556        let typ = resolve::remove_aliases_cow(self, &mut NullInterner, typ);
557        self.find_resolved_tag(&typ, constructor)
558    }
559
560    fn find_resolved_tag(&self, typ: &ArcType, constructor: &Symbol) -> Option<FieldAccess> {
561        match **typ {
562            Type::Variant(ref row) => {
563                let mut iter = row.row_iter();
564                match iter.position(|field| field.name.name_eq(constructor)) {
565                    Some(index) => {
566                        for _ in iter.by_ref() {}
567                        Some(if **iter.current_type() == Type::EmptyRow {
568                            // Non-polymorphic variant, access by index
569                            FieldAccess::Index(index as VmIndex)
570                        } else {
571                            FieldAccess::Name
572                        })
573                    }
574                    None => None,
575                }
576            }
577            _ => None,
578        }
579    }
580
581    /// Compiles an expression to a zero argument function which can be directly fed to the
582    /// interpreter
583    pub fn compile_expr(&mut self, expr: CExpr) -> Result<CompiledModule> {
584        let mut env = FunctionEnvs::new();
585        let id = self.empty_symbol.clone();
586        let typ = expr.env_type_of(&self.globals);
587
588        env.start_function(self, 0, id, typ);
589        info!("COMPILING: {}", expr);
590        self.compile(&expr, &mut env, true)?;
591        let current_line = self.source.line_number_at_byte(expr.span().end());
592        let FunctionEnv {
593            function,
594            free_vars,
595            ..
596        } = env.end_function(self, current_line);
597        Ok(CompiledModule {
598            module_globals: free_vars.into_iter().map(|(symbol, _)| symbol).collect(),
599            function,
600        })
601    }
602
603    fn load_identifier(&self, id: &Symbol, function: &mut FunctionEnvs) -> Result<()> {
604        debug!("Load {}", id);
605        match self
606            .find(id, function)
607            .unwrap_or_else(|| ice!("Undefined variable `{:?}` in {}", id, self.source_name,))
608        {
609            Stack(index) => function.emit(Push(index)),
610            UpVar(index) => function.emit(PushUpVar(index)),
611            // Zero argument constructors can be compiled as integers
612            Constructor(tag, 0) => function.emit(ConstructVariant { tag: tag, args: 0 }),
613            Constructor(..) => {
614                return Err(Error::Message(format!(
615                    "Constructor `{}` is not fully applied",
616                    id
617                )));
618            }
619        }
620        Ok(())
621    }
622
623    fn update_line(&mut self, function: &mut FunctionEnvs, expr: CExpr) {
624        // Don't update the current_line for macro expanded code as the lines in that code do not
625        // come from this module
626        if let Some(current_line) = self.source.line_number_at_byte(expr.span().start()) {
627            function.current_line = current_line;
628        }
629    }
630
631    fn compile(
632        &mut self,
633        mut expr: CExpr,
634        function: &mut FunctionEnvs,
635        tail_position: bool,
636    ) -> Result<()> {
637        // Store a stack of expressions which need to be cleaned up after this "tailcall" loop is
638        // done
639        function.stack.enter_scope();
640
641        self.update_line(function, expr);
642
643        while let Some(next) = self.compile_(expr, function, tail_position)? {
644            expr = next;
645            self.update_line(function, expr);
646        }
647        let count = function.exit_scope(self);
648        function.emit(Slide(count));
649        Ok(())
650    }
651
652    fn compile_<'e>(
653        &mut self,
654        expr: CExpr<'e>,
655        function: &mut FunctionEnvs,
656        tail_position: bool,
657    ) -> Result<Option<CExpr<'e>>> {
658        match *expr {
659            Expr::Const(ref lit, _) => match *lit {
660                Literal::Int(i) => function.emit(PushInt(i)),
661                Literal::Byte(b) => function.emit(PushByte(b)),
662                Literal::Float(f) => function.emit(PushFloat(f.into_inner().into())),
663                Literal::String(ref s) => function.emit_string(self.intern(&s)?),
664                Literal::Char(c) => function.emit(PushInt(u32::from(c).into())),
665            },
666            Expr::Ident(ref id, _) => self.load_identifier(&id.name, function)?,
667            Expr::Let(ref let_binding, ref body) => {
668                let stack_start = function.stack_size;
669                // Index where the instruction to create the first closure should be at
670                let first_index = function.function.instructions.len();
671                match let_binding.expr {
672                    core::Named::Expr(ref bind_expr) => {
673                        self.compile(bind_expr, function, false)?;
674                        function.new_stack_var(
675                            self,
676                            let_binding.name.name.clone(),
677                            let_binding.name.typ.clone(),
678                        );
679                    }
680                    core::Named::Recursive(ref closures) => {
681                        for closure in closures.iter() {
682                            // Add the NewClosure/NewRecord instruction before hand
683                            // it will be fixed later
684                            if closure.args.is_empty() {
685                                function.emit(NewRecord { args: 0, record: 0 });
686                            } else {
687                                function.emit(NewClosure {
688                                    function_index: 0,
689                                    upvars: 0,
690                                });
691                            }
692                            function.new_stack_var(
693                                self,
694                                closure.name.name.clone(),
695                                closure.name.typ.clone(),
696                            );
697                        }
698
699                        for (i, closure) in closures.iter().enumerate() {
700                            if let Some(current_line) = self.source.line_number_at_byte(closure.pos)
701                            {
702                                function.current_line = current_line;
703                            }
704                            function.stack.enter_scope();
705
706                            let offset = first_index + i;
707
708                            function.emit(Push(stack_start + i as VmIndex));
709
710                            if closure.args.is_empty() {
711                                self.compile(closure.expr, function, false)?;
712
713                                let construct_index = function
714                                    .function
715                                    .instructions
716                                    .iter()
717                                    .rposition(|inst| match inst {
718                                        Slide(_) |
719                                        Jump(_) => false,
720                                        _ => true,
721                                    }).unwrap_or_else(|| {
722                                        ice!("Expected record as last expression of recursive binding")
723                                    });
724                                match function.function.instructions[construct_index] {
725                                    ConstructRecord { record, args } => {
726                                        function.stack_size -= 1;
727                                        function.function.instructions[offset] =
728                                            NewRecord { record, args };
729                                        function.function.instructions[construct_index] = CloseData { index: stack_start + i as VmIndex };
730                                    }
731                                    ConstructVariant { tag, args } => {
732                                        function.stack_size -= 1;
733                                        function.function.instructions[offset] =
734                                            NewVariant { tag, args };
735                                        function.function.instructions[construct_index] = CloseData { index: stack_start + i as VmIndex };
736                                    }
737                                    x => ice!(
738                                        "Expected record as last expression of recursive binding `{}`: {:?}\n{}", closure.name.name, x, closure.expr
739                                    ),
740                                }
741                            } else {
742                                let (function_index, vars, cf) = self.compile_lambda(
743                                    &closure.name,
744                                    &closure.args,
745                                    &closure.expr,
746                                    function,
747                                )?;
748                                function.function.instructions[offset] = NewClosure {
749                                    function_index: function_index,
750                                    upvars: vars,
751                                };
752                                function.emit(CloseClosure(vars));
753                                function.stack_size -= vars;
754                                function.function.inner_functions.push(cf);
755                            }
756
757                            function.exit_scope(self);
758                        }
759                    }
760                }
761                return Ok(Some(body));
762            }
763            Expr::Call(func, args) => {
764                if let Expr::Ident(ref id, _) = *func {
765                    if is_primitive(&id.name) && id.name.declared_name() != "#error" {
766                        self.compile_primitive(&id.name, args, function, tail_position)?;
767                        return Ok(None);
768                    }
769
770                    if let Some(Constructor(tag, num_args)) = self.find(&id.name, function) {
771                        for arg in args {
772                            self.compile(arg, function, false)?;
773                        }
774                        function.emit(ConstructVariant {
775                            tag: tag,
776                            args: num_args,
777                        });
778                        return Ok(None);
779                    }
780                }
781                self.compile(func, function, false)?;
782                for arg in args.iter() {
783                    self.compile(arg, function, false)?;
784                }
785                function.emit_call(args.len() as VmIndex, tail_position);
786            }
787            Expr::Match(ref scrutinee, ref alts) => {
788                self.compile(scrutinee, function, false)?;
789                // Indexes for each alternative for a successful match to the alternatives code
790                let mut start_jumps = Vec::new();
791                let typ = alts[0].pattern.env_type_of(self);
792                let typ = resolve::remove_aliases_cow(self, &mut NullInterner, typ.remove_forall());
                // Emit a TestTag + Jump instruction for each alternative which jumps to the
                // alternative's code if TestTag is successful
795                for alt in alts.iter() {
796                    match alt.pattern {
797                        Pattern::Constructor(ref id, _) => {
798                            let tag = self.find_resolved_tag(&typ, &id.name).unwrap_or_else(|| {
799                                ice!(
800                                    "ICE: Could not find tag for {}::{} when matching on \
801                                     expression:\n{}",
802                                    typ,
803                                    self.symbols.string(&id.name),
804                                    scrutinee
805                                )
806                            });
807
808                            match tag {
809                                FieldAccess::Index(tag) => function.emit(TestTag(tag)),
810                                FieldAccess::Name => {
811                                    let interned = self.intern(id.name.as_ref())?;
812                                    let index = function.add_string_constant(interned);
813                                    function.emit(TestPolyTag(index));
814                                }
815                            }
816
817                            start_jumps.push(function.function.instructions.len());
818                            function.emit(CJump(0));
819                        }
820                        Pattern::Record { .. } => {
821                            start_jumps.push(function.function.instructions.len());
822                        }
823                        Pattern::Ident(_) => {
824                            start_jumps.push(function.function.instructions.len());
825                            function.emit(Jump(0));
826                        }
827                        Pattern::Literal(ref l) => {
828                            let lhs_i = function.stack_size() - 1;
829                            match *l {
830                                Literal::Byte(b) => {
831                                    function.emit(Push(lhs_i));
832                                    function.emit(PushByte(b));
833                                    function.emit(ByteEQ);
834                                }
835                                Literal::Int(i) => {
836                                    function.emit(Push(lhs_i));
837                                    function.emit(PushInt(i));
838                                    function.emit(IntEQ);
839                                }
840                                Literal::Char(ch) => {
841                                    function.emit(Push(lhs_i));
842                                    function.emit(PushInt(u32::from(ch).into()));
843                                    function.emit(IntEQ);
844                                }
845                                Literal::Float(f) => {
846                                    function.emit(Push(lhs_i));
847                                    function.emit(PushFloat(f.into_inner().into()));
848                                    function.emit(FloatEQ);
849                                }
850                                Literal::String(ref s) => {
851                                    let prim_symbol = self.symbols.symbol(SymbolData {
852                                        global: true,
853                                        name: "std.prim",
854                                        location: None,
855                                    });
856                                    self.load_identifier(&prim_symbol, function)?;
857                                    let prim_type = self.globals.find_type(&prim_symbol).unwrap();
858                                    let string_eq_symbol = self.symbols.simple_symbol("string_eq");
859                                    function.emit_field(self, &prim_type, &string_eq_symbol)?;
860                                    let lhs_i = function.stack_size() - 2;
861                                    function.emit(Push(lhs_i));
862                                    function.emit_string(self.intern(&s)?);
863                                    function.emit(Call(2));
864                                }
865                            };
866                            start_jumps.push(function.function.instructions.len());
867                            function.emit(CJump(0));
868                        }
869                    }
870                }
871                // Indexes for each alternative from the end of the alternatives code to code
872                // after the alternative
873                let mut end_jumps = Vec::new();
874                for (alt, &start_index) in alts.iter().zip(start_jumps.iter()) {
875                    function.stack.enter_scope();
876                    match alt.pattern {
877                        Pattern::Constructor(_, ref args) => {
878                            function.function.instructions[start_index] =
879                                CJump(function.function.instructions.len() as VmIndex);
880                            function.emit(Split);
881                            for arg in args.iter() {
882                                function.push_stack_var(self, arg.name.clone(), arg.typ.clone());
883                            }
884                        }
885                        Pattern::Record { .. } => {
886                            let typ = &scrutinee.env_type_of(self);
887                            self.compile_let_pattern(&alt.pattern, typ, function)?;
888                        }
889                        Pattern::Ident(ref id) => {
890                            function.function.instructions[start_index] =
891                                Jump(function.function.instructions.len() as VmIndex);
892                            function.new_stack_var(self, id.name.clone(), id.typ.clone());
893                        }
894                        Pattern::Literal(_) => {
895                            function.function.instructions[start_index] =
896                                CJump(function.function.instructions.len() as VmIndex);
897                            // Add a dummy variable to mark where the literal itself is stored
898                            function.new_stack_var(self, self.empty_symbol.clone(), Type::hole());
899                        }
900                    }
901                    self.compile(&alt.expr, function, tail_position)?;
902                    let count = function.exit_scope(self);
903                    function.emit(Slide(count));
904                    end_jumps.push(function.function.instructions.len());
905                    function.emit(Jump(0));
906                }
907                for &index in end_jumps.iter() {
908                    function.function.instructions[index] =
909                        Jump(function.function.instructions.len() as VmIndex);
910                }
911            }
912            Expr::Data(ref id, exprs, _) => {
913                for expr in exprs {
914                    self.compile(expr, function, false)?;
915                }
916                let typ =
917                    resolve::remove_aliases_cow(self, &mut NullInterner, &id.typ.remove_forall());
918                match **typ.remove_forall() {
919                    Type::Record(_) => {
920                        let index = function.add_record_map(
921                            typ.row_iter().map(|field| field.name.clone()).collect(),
922                        );
923                        function.emit(ConstructRecord {
924                            record: index,
925                            args: exprs.len() as u32,
926                        });
927                    }
928                    Type::App(ref array, _) if **array == Type::Builtin(BuiltinType::Array) => {
929                        function.emit(ConstructArray(exprs.len() as VmIndex));
930                    }
931                    Type::Variant(_) => {
932                        match self.find_tag(&typ, &id.name).unwrap_or_else(|| {
933                            ice!("Variant `{}` not found in {:#?}", id.name, typ)
934                        }) {
935                            FieldAccess::Index(tag) => function.emit(ConstructVariant {
936                                tag,
937                                args: exprs.len() as u32,
938                            }),
939                            FieldAccess::Name => {
940                                let variant_name = self.intern(&id.name.definition_name())?;
941                                let tag = function.add_string_constant(variant_name);
942                                function.emit(ConstructPolyVariant {
943                                    tag,
944                                    args: exprs.len() as u32,
945                                });
946                            }
947                        }
948                    }
949                    _ => ice!("ICE: Unexpected data type for {}: {}", id.name, typ),
950                }
951            }
952
953            Expr::Cast(expr, _) => return Ok(Some(expr)),
954        }
955        Ok(None)
956    }
957
958    fn compile_primitive(
959        &mut self,
960        op: &Symbol,
961        args: &[Expr],
962        function: &mut FunctionEnvs,
963        tail_position: bool,
964    ) -> Result<()> {
965        assert!(args.len() == 2, "Invalid primitive application: {}", op);
966        let lhs = &args[0];
967        let rhs = &args[1];
968        if op.as_str() == "&&" {
969            self.compile(lhs, function, false)?;
970            let lhs_end = function.function.instructions.len();
971            function.emit(CJump(lhs_end as VmIndex + 3)); //Jump to rhs evaluation
972            function.emit(ConstructVariant { tag: 0, args: 0 });
973            function.emit(Jump(0)); //lhs false, jump to after rhs
974                                    // Dont count the integer added added above as the next part of the code never
975                                    // pushed it
976            function.stack_size -= 1;
977            self.compile(rhs, function, tail_position)?;
978            // replace jump instruction
979            function.function.instructions[lhs_end + 2] =
980                Jump(function.function.instructions.len() as VmIndex);
981        } else if op.as_str() == "||" {
982            self.compile(lhs, function, false)?;
983            let lhs_end = function.function.instructions.len();
984            function.emit(CJump(0));
985            self.compile(rhs, function, tail_position)?;
986            function.emit(Jump(0));
987            function.function.instructions[lhs_end] =
988                CJump(function.function.instructions.len() as VmIndex);
989            function.emit(ConstructVariant { tag: 1, args: 0 });
990            // Dont count the integer above
991            function.stack_size -= 1;
992            let end = function.function.instructions.len();
993            function.function.instructions[end - 2] = Jump(end as VmIndex);
994        } else {
995            let instr = match self.symbols.string(op) {
996                "#Int+" => AddInt,
997                "#Int-" => SubtractInt,
998                "#Int*" => MultiplyInt,
999                "#Int/" => DivideInt,
1000                "#Int<" | "#Char<" => IntLT,
1001                "#Int==" | "#Char==" => IntEQ,
1002                "#Byte+" => AddByte,
1003                "#Byte-" => SubtractByte,
1004                "#Byte*" => MultiplyByte,
1005                "#Byte/" => DivideByte,
1006                "#Byte<" => ByteLT,
1007                "#Byte==" => ByteEQ,
1008                "#Float+" => AddFloat,
1009                "#Float-" => SubtractFloat,
1010                "#Float*" => MultiplyFloat,
1011                "#Float/" => DivideFloat,
1012                "#Float<" => FloatLT,
1013                "#Float==" => FloatEQ,
1014                _ => {
1015                    self.load_identifier(op, function)?;
1016                    Call(2)
1017                }
1018            };
1019            self.compile(lhs, function, false)?;
1020            self.compile(rhs, function, false)?;
1021            function.emit(instr);
1022        }
1023        Ok(())
1024    }
1025
1026    fn compile_let_pattern(
1027        &mut self,
1028        pattern: &Pattern,
1029        pattern_type: &ArcType,
1030        function: &mut FunctionEnvs,
1031    ) -> Result<()> {
1032        match *pattern {
1033            Pattern::Ident(ref name) => {
1034                function.new_stack_var(self, name.name.clone(), pattern_type.clone());
1035            }
1036            Pattern::Record { ref fields, .. } => {
1037                let typ = resolve::remove_aliases(
1038                    self,
1039                    &mut NullInterner,
1040                    pattern_type.remove_forall().clone(),
1041                );
1042                let typ = typ.remove_forall();
1043                match **typ {
1044                    Type::Record(_) => {
1045                        let mut field_iter = typ.row_iter();
1046                        let number_of_fields = field_iter.by_ref().count();
1047                        let is_polymorphic = **field_iter.current_type() != Type::EmptyRow;
1048                        if fields.len() == 0
1049                            || (number_of_fields > 4 && number_of_fields / fields.len() >= 4)
1050                            || is_polymorphic
1051                        {
1052                            // For pattern matches on large records where only a few of the fields
1053                            // are used we instead emit a series of GetOffset instructions to avoid
1054                            // pushing a lot of unnecessary fields to the stack
1055                            // Polymorphic records also needs to generate field accesses as `Split`
1056                            // would push the fields in a different order depending on the record
1057
1058                            // Add a dummy variable for the record itself so the correct number
1059                            // of slots are removed when exiting
1060                            function.new_stack_var(
1061                                self,
1062                                self.empty_symbol.clone(),
1063                                self.hole.clone(),
1064                            );
1065
1066                            let record_index = function.stack_size() - 1;
1067                            for pattern_field in fields {
1068                                function.emit(Push(record_index));
1069                                function.emit_field(self, &typ, &pattern_field.0.name)?;
1070
1071                                let field_name = pattern_field
1072                                    .1
1073                                    .as_ref()
1074                                    .unwrap_or(&pattern_field.0.name)
1075                                    .clone();
1076                                function.new_stack_var(
1077                                    self,
1078                                    field_name,
1079                                    pattern_field.0.typ.clone(),
1080                                );
1081                            }
1082                        } else {
1083                            function.emit(Split);
1084                            for field in typ.row_iter() {
1085                                let (name, typ) =
1086                                    match fields.iter().find(|tup| tup.0.name.name_eq(&field.name))
1087                                    {
1088                                        Some(&(ref name, ref bind)) => (
1089                                            bind.as_ref().unwrap_or(&name.name).clone(),
1090                                            field.typ.clone(),
1091                                        ),
1092                                        None => (self.empty_symbol.clone(), self.hole.clone()),
1093                                    };
1094                                function.push_stack_var(self, name, typ);
1095                            }
1096                        }
1097                    }
1098                    _ => ice!("Expected record, got {} at {}", typ, pattern),
1099                }
1100            }
1101            Pattern::Constructor(..) => ice!("constructor pattern in let"),
1102            Pattern::Literal(_) => ice!("literal pattern in let"),
1103        }
1104        Ok(())
1105    }
1106
1107    fn compile_lambda(
1108        &mut self,
1109        id: &TypedIdent,
1110        args: &[TypedIdent],
1111        body: CExpr,
1112        function: &mut FunctionEnvs,
1113    ) -> Result<(VmIndex, VmIndex, CompiledFunction)> {
1114        debug!("Compile function {}", id.name);
1115        function.start_function(self, args.len() as VmIndex, id.name.clone(), id.typ.clone());
1116
1117        function.stack.enter_scope();
1118        for arg in args {
1119            function.push_stack_var(self, arg.name.clone(), arg.typ.clone());
1120        }
1121        self.compile(body, function, true)?;
1122
1123        function.exit_scope(self);
1124
1125        // Insert all free variables into the above globals free variables
1126        // if they arent in that lambdas scope
1127        let current_line = self.source.line_number_at_byte(body.span().end());
1128        let f = function.end_function(self, current_line);
1129        for &(ref var, _) in f.free_vars.iter() {
1130            match self
1131                .find(var, function)
1132                .unwrap_or_else(|| panic!("free_vars: find {}", var))
1133            {
1134                Stack(index) => {
1135                    debug!("Load stack {}", var);
1136                    function.emit(Push(index))
1137                }
1138                UpVar(index) => {
1139                    debug!("Load upvar {}", var);
1140                    function.emit(PushUpVar(index))
1141                }
1142                _ => ice!("Free variables can only be on the stack or another upvar"),
1143            }
1144        }
1145        let function_index = function.function.inner_functions.len() as VmIndex;
1146        let free_vars = f.free_vars.len() as VmIndex;
1147        let FunctionEnv { function, .. } = f;
1148        debug!("End compile function {}", id.name);
1149        Ok((function_index, free_vars, function))
1150    }
1151}
1152
#[cfg(all(test, feature = "test"))]
mod tests {
    use super::*;

    use crate::{
        base::symbol::Symbols,
        core::{grammar::ExprParser, Allocator},
        vm::GlobalVmState,
    };

    /// Checks the instructions of `compiled_function` against the next listing yielded by
    /// `instructions` and then does the same, in order, for each of its inner functions.
    ///
    /// Panics if the listings run out or if a listing differs.
    fn verify_instructions<'a>(
        compiled_function: &CompiledFunction,
        instructions: &mut impl Iterator<Item = &'a [Instruction]>,
    ) {
        let expected = instructions.next().expect("Instructions");
        assert_eq!(compiled_function.instructions, expected);

        for inner in compiled_function.inner_functions.iter() {
            verify_instructions(inner, instructions);
        }
    }

    /// Parses and compiles the core expression `source` and asserts that the compiled
    /// functions consist of exactly `instructions` (one listing per function, outermost
    /// function first).
    fn assert_instructions(source: &str, instructions: &[&[Instruction]]) {
        let mut symbols = Symbols::new();
        let allocator = Allocator::new();
        let parser = ExprParser::new();
        let expr = parser.parse(&mut symbols, &allocator, source).unwrap();

        let type_infos = TypeInfos::new();
        let vm_state = GlobalVmState::new();
        let file_map = FileMap::new(String::new().into(), String::new());
        let symbol_module = SymbolModule::new("test".into(), &mut symbols);
        let mut compiler = Compiler::new(
            &type_infos,
            &vm_state,
            symbol_module,
            &file_map,
            "test".into(),
            false,
        );
        let module = compiler.compile_expr(&expr).unwrap();

        let mut expected = instructions.iter().cloned();
        verify_instructions(&module.function, &mut expected);
    }

    /// Two records which refer to each other
    #[test]
    fn recursive_record() {
        let _ = ::env_logger::try_init();

        let expected: &[&[Instruction]] = &[&[
            NewRecord { args: 1, record: 0 },
            NewRecord { args: 1, record: 1 },
            Push(0),
            Push(1),
            CloseData { index: 0 },
            Push(1),
            Push(0),
            CloseData { index: 1 },
            Push(1),
            Slide(2),
            Return,
        ]];

        assert_instructions(
            "rec let a = { b }
             rec let b = { a }
             in b",
            expected,
        );
    }

    /// Two records which refer to each other through the functions they contain
    #[test]
    fn recursive_record_with_functions() {
        let _ = ::env_logger::try_init();

        let top_level: &[Instruction] = &[
            NewRecord { args: 1, record: 0 },
            NewRecord { args: 1, record: 0 },
            // `a`
            Push(0),
            NewClosure {
                function_index: 0,
                upvars: 1,
            },
            Push(3),
            Push(1),
            CloseClosure(1),
            Push(3),
            CloseData { index: 0 },
            Slide(1), // The closure is removed again
            // `b`
            Push(1),
            NewClosure {
                function_index: 1,
                upvars: 1,
            },
            Push(4),
            Push(0),
            CloseClosure(1),
            Push(4),
            CloseData { index: 1 },
            Slide(1), // The closure is removed again
            // body (`in b`)
            Push(1),
            Slide(2),
            Return,
        ];
        // Both inner functions just return their only upvar
        let inner: &[Instruction] = &[PushUpVar(0), Return];

        assert_instructions(
            "rec let a =
                rec let f x = b
                in { f }
             rec let b =
                rec let f y = a
                in { f }
             in b",
            &[top_level, inner, inner],
        );
    }
}