Skip to main content

nickel_lang_core/term/pattern/
compile.rs

1//! Compilation of pattern matching down to pattern-less Nickel code.
2//!
3//! # Algorithm
4//!
//! Compiling patterns amounts to generating a decision tree - concretely, a term composed mostly
//! of nested if-then-else expressions - which either succeeds in matching a value and returns the
//! bindings of the pattern variables, or fails and returns `null`.
8//!
9//! Compilation of pattern matching is a well-studied problem in the literature, where efficient
10//! algorithms try to avoid the duplication of checks by "grouping" them in a smart way. A standard
11//! resource on this topic is the paper [_Compiling Pattern Matching to Good Decision
12//! Trees_](https://dl.acm.org/doi/10.1145/1411304.1411311) by Luc Maranget.
13//!
14//! The current version of pattern compilation in Nickel is naive: it simply compiles each pattern
15//! to a checking expression and tries them all until one works. We don't expect pattern matching
//! to be relevant for performance anytime soon (arguably, there are much more impactful aspects
//! to handle before that). We might revisit this in the future if pattern matching turns out to be
18//! a bottleneck.
19//!
20//! Most building blocks are generated programmatically rather than written out as e.g. members of
21//! the [crate::stdlib::internals] module. While clunkier, this makes it easier to change
22//! the compilation strategy in the future and is more efficient in the current setting (combining
//! building blocks from the standard library would require many more function applications, while
24//! we can generate inlined versions on-the-fly here).
25use std::collections::HashMap;
26
27use indexmap::IndexMap;
28use nickel_lang_parser::{
29    ast::pattern::{
30        ArrayPattern, ConstantPattern, ConstantPatternData, EnumPattern, OrPattern, RecordPattern,
31        TailPattern, bindings::Bindings,
32    },
33    identifier::LocIdent,
34};
35use smallvec::SmallVec;
36
37use crate::{
38    ast::{
39        compat::ToMainline,
40        pattern::{Pattern, PatternData},
41    },
42    error::EvalErrorKind,
43    eval::value::{NickelValue, RecordData},
44    metrics::increment,
45    mk_app,
46    position::{PosIdx, PosTable},
47    term::{
48        BinaryOp, BindingType, FunData, LetAttrs, NAryOp, RecordExtKind, RecordOpKind, Term,
49        UnaryOp, make,
50        record::{Field, FieldMetadata},
51    },
52};
53
54/// A branch of a match expression.
55///
56/// This is a weird mix of `Ast` and `NickelValue`, because it's used during the
57/// AST lowering. Also, `Pattern` doesn't have a runtime representation.
58#[derive(Debug, Clone, PartialEq)]
59pub struct MatchBranch<'ast> {
60    /// The pattern on the left hand side of `=>`.
61    pub pattern: Pattern<'ast>,
62    /// A potential guard, which is an additional side-condition defined as `if cond`. The value
63    /// stored in this field is the boolean condition itself.
64    pub guard: Option<NickelValue>,
65    /// The body of the branch, on the right hand side of `=>`.
66    pub body: NickelValue,
67}
68
69/// Generate a standard `%record/insert%` primop as generated by the parser.
70fn record_insert() -> BinaryOp {
71    BinaryOp::RecordInsert {
72        ext_kind: RecordExtKind::WithValue,
73        metadata: Default::default(),
74        pending_contracts: Default::default(),
75        // We don't really care for optional fields here and we don't need to filter them out
76        op_kind: RecordOpKind::ConsiderAllFields,
77    }
78}
79
80/// Generate a Nickel expression which checks if a field is defined in a record provided as a
81/// variable, and if not, insert a default value. Return the result (either the original record
82/// unchanged, or the original record with the default value). The resulting record is guaranteed
83/// to have the `field` defined. The implementation uses merging to avoid dropping the contracts
84/// and other metadata affected to `field`, if the field exists but has no definition.
85///
86/// More precisely, [with_default_value] generates the following code:
87///
88/// ```nickel
89/// if !(%record/field_is_defined% "<field>" record_id) then
90///   if %record/has_field% "<field>" record_id then
91///     record_id & { "<field>" = default }
92///   else
93///     # Merging is potentially more costly, and we can just fallback to record insertion here.
94///     %record/insert% "<field>" record_id default
95/// else
96///   record_id
97/// ```
98pub(crate) fn with_default_value(
99    pos_table: &mut PosTable,
100    record_id: LocIdent,
101    field: LocIdent,
102    default: NickelValue,
103) -> NickelValue {
104    let field_not_defined = make::op1(
105        UnaryOp::BoolNot,
106        make::op2(
107            BinaryOp::RecordFieldIsDefined(RecordOpKind::ConsiderAllFields),
108            NickelValue::string_posless(field.label()),
109            Term::Var(record_id),
110        ),
111    );
112
113    let has_field = make::op2(
114        BinaryOp::RecordHasField(RecordOpKind::ConsiderAllFields),
115        NickelValue::string_posless(field.label()),
116        Term::Var(record_id),
117    );
118
119    let insert = mk_app!(
120        make::op2(
121            record_insert(),
122            NickelValue::string_posless(field),
123            make::var(record_id)
124        ),
125        default.clone()
126    );
127
128    let inner_let_if = make::if_then_else(
129        has_field,
130        update_with_merge(pos_table, record_id, field, Field::from(default)),
131        insert,
132    );
133
134    make::if_then_else(field_not_defined, inner_let_if, Term::Var(record_id))
135}
136
137/// Update a record field by merging it with a singleton record containing the new value.
138///
139/// ```nickel
140/// record_id & { "<id>" = <field> }
141/// ```
142fn update_with_merge(
143    pos_table: &mut PosTable,
144    record_id: LocIdent,
145    id: LocIdent,
146    field: Field,
147) -> NickelValue {
148    use crate::{ast::MergeKind, label::MergeLabel};
149
150    let annot_spans = field
151        .metadata
152        .iter_annots()
153        .filter_map(|labeled_ty| pos_table.get(labeled_ty.label.span).into_opt());
154    let value_spans = field
155        .value
156        .as_ref()
157        .and_then(|v| pos_table.get(v.pos_idx()).into_opt());
158    let span = annot_spans
159        .chain(value_spans)
160        // We fuse all the definite spans together.
161        // unwrap(): all span should come from the same file
162        .reduce(|span1, span2| span1.fuse(span2).unwrap());
163
164    let singleton = NickelValue::record_posless(RecordData {
165        fields: IndexMap::from_iter([(id, field)]),
166        ..Default::default()
167    });
168    // Right now, patterns are compiled on-the-fly during evaluation. We thus need to
169    // perform the gen_pending_contract transformation manually, or the contracts will
170    // just be ignored. One step suffices, as we create a singleton record that doesn't
171    // contain other non-transformed records (the default value, if any, has been
172    // transformed normally).
173    //
174    // unwrap(): typechecking ensures that there are no unbound variables at this point
175    let singleton =
176        crate::transform::gen_pending_contracts::transform_one(pos_table, singleton).unwrap();
177
178    // Since we generate a non-recursive record and inject it in the evaluation, we must manually
179    // enforce it's properly closurized.
180    let singleton = NickelValue::term_posless(Term::Closurize(singleton));
181
182    let merge_label = MergeLabel {
183        span: pos_table.push(span.into()),
184        kind: MergeKind::Standard,
185    };
186
187    make::op2(
188        BinaryOp::Merge(merge_label),
189        Term::Var(record_id),
190        singleton,
191    )
192}
193
194pub trait CompilePart {
195    /// Compile part of a broader pattern to a Nickel expression.
196    ///
197    /// `value_id` is the value being matched on. The caller of `compile_part`
198    /// is in charge of putting it into the environment.
199    ///
200    /// `match_cont` and `fail_cont` are continuations: `match_cont` is the
201    /// one to use if this pattern matches, while `fail_cont` is the one to
202    /// use if this pattern fails to match. Both `match_cont` and `fail_cont`
203    /// are allowed to contain -- as free variables -- identifiers bound by this
204    /// match expression.
205    ///
206    /// Pattern compilation must avoid shadowing any variables. In particular,
207    /// `compile_part` is not responsible for binding any of the free variables
208    /// in `match_cont` or `fail_cont`.
209    ///
210    /// As a motivating example, consider a match block like
211    ///
212    /// ```nickel
213    /// let foo = 1 in
214    /// [2, 'Blah] |> match {
215    ///   [foo, 'Bar] => foo,
216    ///   [_, 'Blah] => foo,
217    /// }
218    /// ```
219    ///
220    /// This gets compiled by calling `compile_part` on the first match arm,
221    /// with a `fail_cont` continuation that falls through to the second match
222    /// arm. If the first match arm's pattern were to bind `foo` to `2` and
223    /// then fail to match `'Bar`, the second match arm would see the wrong
224    /// binding for `foo`.
225    ///
226    /// To avoid this unwanted shadowing, we accept a `bindings` mapping that
227    /// maps the pattern variables to pattern-local identifiers. In the example
228    /// above, `bindings` would be `{ foo => %1 }`. Pattern compilation can
229    /// safely bind `%1` because it doesn't shadow anything, and the actual
230    /// `foo` binding will be done by the match arm just before entering the
231    /// body.
232    fn compile_part(
233        &self,
234        pos_table: &mut PosTable,
235        value_id: LocIdent,
236        match_cont: NickelValue,
237        fail_cont: NickelValue,
238        bindings: &HashMap<LocIdent, LocIdent>,
239    ) -> NickelValue;
240}
241
242impl<'ast> CompilePart for Pattern<'ast> {
243    // Compilation of the top-level pattern wrapper. If there's an alias,
244    // compiles to
245    //
246    // pattern_data.compile(value_id, let <alias> = value_id in <match_cont>, <fail_cont>)
247    //
248    // where `<alias>` is a generated id, not the actual id bound in the pattern.
249    fn compile_part(
250        &self,
251        pos_table: &mut PosTable,
252        value_id: LocIdent,
253        match_cont: NickelValue,
254        fail_cont: NickelValue,
255        bindings: &HashMap<LocIdent, LocIdent>,
256    ) -> NickelValue {
257        let match_cont = if let Some(alias) = self.alias {
258            make::let_one_in(bindings[&alias], Term::Var(value_id), match_cont)
259        } else {
260            match_cont
261        };
262
263        self.data
264            .compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
265    }
266}
267
268impl<'ast> CompilePart for PatternData<'ast> {
269    fn compile_part(
270        &self,
271        pos_table: &mut PosTable,
272        value_id: LocIdent,
273        match_cont: NickelValue,
274        fail_cont: NickelValue,
275        bindings: &HashMap<LocIdent, LocIdent>,
276    ) -> NickelValue {
277        match self {
278            PatternData::Wildcard => match_cont,
279            PatternData::Any(id) => make::let_one_in(bindings[id], Term::Var(value_id), match_cont),
280            PatternData::Record(pat) => {
281                pat.compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
282            }
283            PatternData::Array(pat) => {
284                pat.compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
285            }
286            PatternData::Enum(pat) => {
287                pat.compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
288            }
289            PatternData::Constant(pat) => {
290                pat.compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
291            }
292            PatternData::Or(pat) => {
293                pat.compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
294            }
295        }
296    }
297}
298
299impl<'ast> CompilePart for ConstantPattern<'ast> {
300    fn compile_part(
301        &self,
302        pos_table: &mut PosTable,
303        value_id: LocIdent,
304        match_cont: NickelValue,
305        fail_cont: NickelValue,
306        bindings: &HashMap<LocIdent, LocIdent>,
307    ) -> NickelValue {
308        self.data
309            .compile_part(pos_table, value_id, match_cont, fail_cont, bindings)
310    }
311}
312
313impl<'ast> CompilePart for ConstantPatternData<'ast> {
314    fn compile_part(
315        &self,
316        _: &mut PosTable,
317        value_id: LocIdent,
318        match_cont: NickelValue,
319        fail_cont: NickelValue,
320        _bindings: &HashMap<LocIdent, LocIdent>,
321    ) -> NickelValue {
322        let compile_constant = |nickel_type: &str, value: NickelValue| {
323            // if %typeof% value_id == '<nickel_type> && value_id == <value> then
324            //   match_cont
325            // else
326            //   fail_cont
327
328            // %typeof% value_id == '<nickel_type>
329            let type_matches = make::op2(
330                BinaryOp::Eq,
331                make::op1(UnaryOp::Typeof, Term::Var(value_id)),
332                NickelValue::enum_tag_posless(nickel_type),
333            );
334
335            // value_id == <value>
336            let value_matches = make::op2(BinaryOp::Eq, Term::Var(value_id), value);
337
338            // <type_matches> && <value_matches>
339            let if_condition = mk_app!(make::op1(UnaryOp::BoolAnd, type_matches), value_matches);
340
341            make::if_then_else(if_condition, match_cont, fail_cont)
342        };
343
344        match self {
345            ConstantPatternData::Bool(b) => {
346                compile_constant("Bool", NickelValue::bool_value_posless(*b))
347            }
348            ConstantPatternData::Number(n) => {
349                compile_constant("Number", NickelValue::number_posless((*n).clone()))
350            }
351            ConstantPatternData::String(s) => {
352                compile_constant("String", NickelValue::string_posless(s.to_owned()))
353            }
354            ConstantPatternData::Null => compile_constant("Other", NickelValue::null()),
355        }
356    }
357}
358
359impl<'ast> CompilePart for OrPattern<'ast> {
360    // Compilation of or patterns.
361    //
362    // We fold over the patterns in reverse order. In the fold, the accumulator is
363    // `fail_cont`, because if an element in the or pattern fails to match then it
364    // falls through to the next case (which is what's in the accumulator because
365    // we're folding in reverse).
366    //
367    // We clone `match_cont` for each alternative. This looks wasteful
368    // when pretty-printing, but `NickelValue` is cheaply clonable so it's
369    // probably fine. Another option would be to add an indirection through
370    // a `Term::Var`.
371    fn compile_part(
372        &self,
373        pos_table: &mut PosTable,
374        value_id: LocIdent,
375        match_cont: NickelValue,
376        fail_cont: NickelValue,
377        bindings: &HashMap<LocIdent, LocIdent>,
378    ) -> NickelValue {
379        self.patterns.iter().rev().fold(fail_cont, |cont, pattern| {
380            pattern.compile_part(pos_table, value_id, match_cont.clone(), cont, bindings)
381        })
382    }
383}
384
impl<'ast> CompilePart for RecordPattern<'ast> {
    // Compilation of a record pattern.
    //
    // We check the match shallowly before extracting bindings: all the non-defaulted fields of
    // the pattern must be defined in the record and, if the pattern isn't open, the record must
    // have no extra fields. The latter is done just by checking the number of fields, because
    // typechecking has already ensured that there are no duplicate bindings.
    //
    // Right after the `'Record` type check, the generated code inserts default values: if there
    // are any default values in the pattern, `local_value_id` is redefined to include them. This
    // is built last in this function (because we're in a sort of continuation-passing style),
    // but keep in mind that the shallow checks already work on a record that has the default
    // values filled in.
    //
    // if %typeof% value_id == 'Record then
    //   let local_value_id = value_id in
    //   <for each field with a default value (last field of the pattern first)>
    //     let local_value_id = <with_default_value(local_value_id, field, default)> in
    //   <end for>
    //
    //   <let length_check = if self.is_open() { true } else { (<fields of local_value_id, ignoring empty optional fields> |> %array/length%) == <self.patterns.len()> }>
    //
    //   if (%record/field_is_defined% "field0" local_value_id) && (%record/field_is_defined% "field1" local_value_id) && ... && <length_check> then
    //     <fold block (see below)>
    //   else
    //     <fail_cont>
    // else
    //   <fail_cont>
    fn compile_part(
        &self,
        pos_table: &mut PosTable,
        value_id: LocIdent,
        match_cont: NickelValue,
        fail_cont: NickelValue,
        bindings: &HashMap<LocIdent, LocIdent>,
    ) -> NickelValue {
        // <a> && <b>
        // (with a small optimization if one is `true`)
        fn and(a: NickelValue, b: NickelValue) -> NickelValue {
            if a.is_bool_true() {
                b
            } else if b.is_bool_true() {
                a
            } else {
                mk_app!(make::op1(UnaryOp::BoolAnd, a), b)
            }
        }

        // Pattern-local name for the matched record. It is rebound below whenever default
        // values are inserted or field annotations are merged in.
        let local_value_id = LocIdent::fresh();

        // (%record/field_is_defined% "field0" local_value_id) && (%record/field_is_defined% "field1" local_value_id) && ...
        //
        // Where field0, field1, etc. are all the non-defaulted fields, in pattern order (the
        // fold runs in reverse so that the conjunction is built from the last field outwards).
        // Defaulted fields are skipped: they are always defined once defaults are inserted.
        let has_fields = self.patterns.iter().rev().fold(
            NickelValue::bool_true(),
            |has_other_fields, field_pat| {
                if field_pat.default.is_some() {
                    has_other_fields
                } else {
                    let has_field = make::op2(
                        BinaryOp::RecordFieldIsDefined(RecordOpKind::ConsiderAllFields),
                        NickelValue::string_posless(field_pat.matched_id.label()),
                        Term::Var(local_value_id),
                    );

                    and(has_field, has_other_fields)
                }
            },
        );

        // If this pattern is non-open, this is
        //
        // (<fields of local_value_id, ignoring empty optional fields> |> %array/length%) == <self.patterns.len()>
        //
        // i.e. `RecordFields(RecordOpKind::IgnoreEmptyOpt)` followed by `ArrayLength`. For an
        // open pattern, there's nothing to check.
        let has_right_length = if self.is_open() {
            NickelValue::bool_true()
        } else {
            make::op2(
                BinaryOp::Eq,
                make::op1(
                    UnaryOp::ArrayLength,
                    make::op1(
                        UnaryOp::RecordFields(RecordOpKind::IgnoreEmptyOpt),
                        Term::Var(local_value_id),
                    ),
                ),
                NickelValue::number_posless(self.patterns.len()),
            )
        };

        let matches_shallow = and(has_fields, has_right_length);

        let match_cont = if let TailPattern::Capture(tail) = self.tail {
            // let
            //   <tail> = (%record/remove_with_opts% fieldN (... (%record/remove_with_opts% field0 local_value_id) ...))
            // in <match_cont>
            //
            // where field0, ..., fieldN are all the fields matched in the pattern, and <tail> is
            // the pattern-local id from `bindings`, not the actual id bound in the pattern. Note
            // that the closure's `tail` parameter is the fold accumulator, which shadows the
            // captured identifier only inside the closure.
            let tail_value = self.patterns.iter().fold(
                NickelValue::term_posless(Term::Var(local_value_id)),
                |tail, field_pat| {
                    make::op2(
                        BinaryOp::RecordRemove(RecordOpKind::ConsiderAllFields),
                        NickelValue::string_posless(field_pat.matched_id.label()),
                        tail,
                    )
                },
            );
            make::let_one_in(bindings[&tail], tail_value, match_cont)
        } else {
            match_cont
        };

        // The main fold block where we extract all of the bound variables.
        // Because there's an eager check before we enter this block, we
        // can't fail to extract any of the variables. However, there could be
        // sub-patterns that fail to match.
        //
        // Our accumulator is `match_cont`, the continuation to take when
        // the match succeeds. Folding in reverse means the generated code handles the fields in
        // pattern order, and the tail binding (if any) ends up innermost.
        //
        // There are two cases in the match block, depending on whether the
        // field in the pattern has an annotation. If it doesn't, we yield
        //
        // let <local_field_id> = %static_access(<field>)% local_value_id in
        // <field_pat.pattern.compile_part(local_field_id, match_cont, fail_cont)>
        //
        // If there is an annotation, the annotation is first merged into the record, before the
        // field is extracted:
        //
        // let local_value_id = <update_with_merge(local_value_id, field, annotation)> in
        // let <local_field_id> = %static_access(<field>)% local_value_id in
        // <field_pat.pattern.compile_part(local_field_id, match_cont, fail_cont)>
        //
        // Since `local_value_id` is rebound, the merged annotation is also visible to the
        // fields extracted afterwards and to the tail.
        let fold_block = self
            .patterns
            .iter()
            .rev()
            .fold(match_cont, |match_cont, field_pat| {
                let local_field_id = LocIdent::fresh();
                let field = field_pat.matched_id;

                let match_cont = field_pat.pattern.compile_part(
                    pos_table,
                    local_field_id,
                    match_cont,
                    fail_cont.clone(),
                    bindings,
                );

                // %static_access(field)% local_value_id
                let extracted_value =
                    make::op1(UnaryOp::RecordAccess(field), Term::Var(local_value_id));

                // let <local_field_id> = <extracted_value> in <match_cont>
                let match_cont = make::let_one_in(local_field_id, extracted_value, match_cont);

                // <if !field.annotation.is_empty()>
                //   let local_value_id = <update_with_merge...> in <match_cont>
                // <else>
                //   <match_cont>
                // <end if>
                if !field_pat.annotation.is_empty() {
                    let annotation = field_pat.annotation.to_mainline(pos_table);
                    make::let_one_in(
                        local_value_id,
                        update_with_merge(
                            pos_table,
                            local_value_id,
                            field,
                            Field::from(FieldMetadata {
                                annotation,
                                ..Default::default()
                            }),
                        ),
                        match_cont,
                    )
                } else {
                    match_cont
                }
            });

        // if <matches_shallow> then <fold_block> else <fail_cont>
        let inner_if = make::if_then_else(matches_shallow, fold_block, fail_cont.clone());

        // %typeof% value_id == 'Record
        let is_record: NickelValue = make::op2(
            BinaryOp::Eq,
            make::op1(UnaryOp::Typeof, Term::Var(value_id)),
            NickelValue::enum_tag_posless("Record"),
        );

        // <for each field with a default value>
        //   let local_value_id = <with_default_value(local_value_id, field, default)> in
        // <end for>
        // <inner_if>
        //
        // The fold runs forward, so the last defaulted field of the pattern ends up outermost
        // (its default is inserted first). The order doesn't matter since each `let` fills a
        // different field.
        let with_defaults = self.patterns.iter().fold(inner_if, |cont, field_pat| {
            if let Some(default) = field_pat.default.as_ref() {
                let default = default.to_mainline(pos_table);
                make::let_one_in(
                    local_value_id,
                    with_default_value(pos_table, local_value_id, field_pat.matched_id, default),
                    cont,
                )
            } else {
                cont
            }
        });

        // if <is_record> then let local_value_id = value_id in <with_defaults> else <fail_cont>
        make::if_then_else(
            is_record,
            make::let_one_in(local_value_id, Term::Var(value_id), with_defaults),
            fail_cont,
        )
    }
}
603
604impl<'ast> CompilePart for ArrayPattern<'ast> {
605    // Compilation of an array pattern.
606    //
607    // let value_len = %array/length% value_id in
608    //
609    // <if self.is_open()>
610    // if %typeof% value_id == 'Array && value_len >= <self.patterns.len()>
611    // <else>
612    // if %typeof% value_id == 'Array && value_len == <self.patterns.len()>
613    // <end if>
614    //   <fold block (see below)>
615    // else
616    //   <fail_cont>
617    fn compile_part(
618        &self,
619        pos_table: &mut PosTable,
620        value_id: LocIdent,
621        match_cont: NickelValue,
622        fail_cont: NickelValue,
623        bindings: &HashMap<LocIdent, LocIdent>,
624    ) -> NickelValue {
625        let value_len_id = LocIdent::fresh();
626        let pats_len = NickelValue::number_posless(self.patterns.len());
627
628        let match_cont = if let TailPattern::Capture(rest) = self.tail {
629            // let <rest> = %array/slice% <self.patterns.len()> value_len value_id in <match_cont>
630            make::let_one_in(
631                bindings[&rest],
632                make::opn(
633                    NAryOp::ArraySlice,
634                    vec![
635                        pats_len.clone(),
636                        Term::Var(value_len_id).into(),
637                        Term::Var(value_id).into(),
638                    ],
639                ),
640                match_cont,
641            )
642        } else {
643            match_cont
644        };
645
646        // <fold (idx) in 0..self.patterns.len()
647        //  - match_cont is the accumulator
648        //  - initial accumulator is `match_cont`, which includes the tail bindings if necessary
649        //  >
650        //
651        //   let local_value_id = %array/at% <idx> value_id in
652        //   <self.patterns[idx].compile_part(local_value_id, match_cont, fail_cont)>
653        //
654        // <end fold>
655        let fold_block: NickelValue = self.patterns.iter().enumerate().rev().fold(
656            match_cont,
657            |match_cont, (idx, elem_pat)| {
658                let local_value_id = LocIdent::fresh();
659
660                // %array/at% idx value_id
661                let extracted_value = make::op2(
662                    BinaryOp::ArrayAt,
663                    Term::Var(value_id),
664                    NickelValue::number_posless(idx),
665                );
666
667                make::let_one_in(
668                    local_value_id,
669                    extracted_value,
670                    elem_pat.compile_part(
671                        pos_table,
672                        local_value_id,
673                        match_cont,
674                        fail_cont.clone(),
675                        bindings,
676                    ),
677                )
678            },
679        );
680
681        // %typeof% value_id == 'Array
682        let is_array: NickelValue = make::op2(
683            BinaryOp::Eq,
684            make::op1(UnaryOp::Typeof, Term::Var(value_id)),
685            NickelValue::enum_tag_posless("Array"),
686        );
687
688        let comp_op = if self.is_open() {
689            BinaryOp::GreaterOrEq
690        } else {
691            BinaryOp::Eq
692        };
693
694        // <is_array> && value_len <comp_op> <self.patterns.len()>
695        let outer_check = mk_app!(
696            make::op1(UnaryOp::BoolAnd, is_array),
697            make::op2(comp_op, Term::Var(value_len_id), pats_len)
698        );
699
700        // if <outer_check> then <fold_block> else <fail_cont>
701        let outer_if = make::if_then_else(outer_check, fold_block, fail_cont);
702
703        // let <value_len_id> = %array/length% <value_id> in <outer_if>
704        make::let_one_in(
705            value_len_id,
706            make::op1(UnaryOp::ArrayLength, Term::Var(value_id)),
707            outer_if,
708        )
709    }
710}
711
712impl<'ast> CompilePart for EnumPattern<'ast> {
713    fn compile_part(
714        &self,
715        pos_table: &mut PosTable,
716        value_id: LocIdent,
717        match_cont: NickelValue,
718        fail_cont: NickelValue,
719        bindings: &HashMap<LocIdent, LocIdent>,
720    ) -> NickelValue {
721        // %enum/get_tag% value_id == '<self.tag>
722        let tag_matches = make::op2(
723            BinaryOp::Eq,
724            make::op1(UnaryOp::EnumGetTag, Term::Var(value_id)),
725            NickelValue::enum_tag_posless(self.tag),
726        );
727
728        if let Some(pat) = &self.pattern {
729            // if %enum/is_variant% value_id && %enum/get_tag% value_id == '<self.tag> then
730            //   let next_value_id = %enum/get_arg% value_id in
731            //   <pattern.compile(next_value_id, match_cont, fail_cont)>
732            // else
733            //   fail_cont
734
735            // %enum/is_variant% value_id && <tag_matches>
736            let if_condition = mk_app!(
737                make::op1(
738                    UnaryOp::BoolAnd,
739                    make::op1(UnaryOp::EnumIsVariant, Term::Var(value_id)),
740                ),
741                tag_matches
742            );
743
744            let next_value_id = LocIdent::fresh();
745
746            make::if_then_else(
747                if_condition,
748                make::let_one_in(
749                    next_value_id,
750                    make::op1(UnaryOp::EnumGetArg, Term::Var(value_id)),
751                    pat.compile_part(
752                        pos_table,
753                        next_value_id,
754                        match_cont,
755                        fail_cont.clone(),
756                        bindings,
757                    ),
758                ),
759                fail_cont,
760            )
761        } else {
762            // if %typeof% value_id == 'Enum && !(%enum/is_variant% value_id) && <tag_matches> then
763            //   match_cont
764            // else
765            //   fail_cont
766
767            // %typeof% value_id == 'Enum
768            let is_enum = make::op2(
769                BinaryOp::Eq,
770                make::op1(UnaryOp::Typeof, Term::Var(value_id)),
771                NickelValue::enum_tag_posless("Enum"),
772            );
773
774            // !(%enum/is_variant% value_id)
775            let is_enum_tag = make::op1(
776                UnaryOp::BoolNot,
777                make::op1(UnaryOp::EnumIsVariant, Term::Var(value_id)),
778            );
779
780            // <is_enum> && <is_enum_tag> && <tag_matches>
781            let if_condition = mk_app!(
782                make::op1(UnaryOp::BoolAnd, is_enum,),
783                mk_app!(make::op1(UnaryOp::BoolAnd, is_enum_tag,), tag_matches)
784            );
785
786            make::if_then_else(if_condition, match_cont, fail_cont)
787        }
788    }
789}
790
/// Compilation of a pattern-matching construct (e.g. a whole `match` expression) down to
/// pattern-less Nickel code.
pub trait Compile {
    /// Compile a match expression to a Nickel expression matching `value`.
    ///
    /// - `pos_table`: position table, threaded through to pattern compilation by implementations
    ///   that need to register positions (some implementations ignore it).
    /// - `value`: the expression being matched on (the scrutinee).
    /// - `pos_idx`: position of the match expression, attached to the generated code.
    fn compile(self, pos_table: &mut PosTable, value: NickelValue, pos_idx: PosIdx) -> NickelValue;
}
795
/// Content of a match expression.
#[derive(Debug, Clone, PartialEq)]
pub struct MatchData<'ast> {
    /// The branches of the match, in source order. When compiled, branches are tried in this
    /// order and the first one whose pattern matches (and whose guard, if any, is `true`) is
    /// selected.
    pub branches: Vec<MatchBranch<'ast>>,
}
801
802impl<'ast> Compile for MatchData<'ast> {
803    // Compilation of a full match expression, represented as an array of
804    // its branches. Code between < and > is Rust code, think of it as a kind
805    // of templating. Note that some special cases compile differently as
806    // optimizations.
807    //
808    // let value_id = value in <fold_block>
809    //
810    // The main action is in <fold_block> (see below), which folds over the match branches.
811    fn compile(
812        mut self,
813        pos_table: &mut PosTable,
814        value: NickelValue,
815        pos_idx: PosIdx,
816    ) -> NickelValue {
817        increment!("pattern_compile");
818
819        if self.branches.iter().all(|branch| {
820            // While we could get something working even with a guard, it's a bit more work and
821            // there's no current incentive to do so (a guard on a tags-only match is arguably less
822            // common, as such patterns don't bind any variable). For the time being, we just
823            // exclude guards from the tags-only optimization.
824            matches!(
825                branch.pattern.data,
826                PatternData::Enum(EnumPattern { pattern: None, .. }) | PatternData::Wildcard
827            ) && branch.guard.is_none()
828        }) {
829            let wildcard_pat = self.branches.iter().enumerate().find_map(
830                |(
831                    idx,
832                    MatchBranch {
833                        pattern,
834                        guard,
835                        body,
836                    },
837                )| {
838                    if matches!((&pattern.data, guard), (PatternData::Wildcard, None)) {
839                        Some((idx, body.clone()))
840                    } else {
841                        None
842                    }
843                },
844            );
845
846            // If we find a wildcard pattern, we record its index in order to discard all the
847            // patterns coming after the wildcard, because they are unreachable.
848            let default = if let Some((idx, body)) = wildcard_pat {
849                self.branches.truncate(idx + 1);
850                Some(body)
851            } else {
852                None
853            };
854
855            let tags_only = self
856                .branches
857                .into_iter()
858                .filter_map(
859                    |MatchBranch {
860                         pattern,
861                         guard: _,
862                         body,
863                     }| {
864                        if let PatternData::Enum(EnumPattern { tag, .. }) = pattern.data {
865                            Some((*tag, body))
866                        } else {
867                            None
868                        }
869                    },
870                )
871                .collect();
872
873            return TagsOnlyMatch {
874                branches: tags_only,
875                default,
876            }
877            .compile(pos_table, value, pos_idx);
878        }
879
880        let error_case = NickelValue::term(
881            Term::RuntimeError(Box::new(EvalErrorKind::NonExhaustiveMatch {
882                value: value.clone(),
883                pos: pos_idx,
884            })),
885            pos_idx,
886        );
887
888        let value_id = LocIdent::fresh();
889
890        // The fold block.
891        //
892        // As in or pattern compilation, we fold in reverse over the
893        // alternatives, and our accumulator is the failure continuation. The
894        // initial accumulator is a runtime error.
895        //
896        // This is the point at which we actually bind the variables from the
897        // pattern. We start by generating a map of pattern-bound variables
898        // to fresh variables (e.g. `{ foo => %1, bar => %2 }`) if `foo` and
899        // `bar` are the variables bound by the current pattern. We pass
900        // this map to the pattern compilation so that it will bind `%1` to
901        // the expression that we'll eventually bind `foo` to.
902        // See `CompilePart::compile_part` for why we do this.
903        //
904        // <for branch in branches.rev()
905        //  - fail_cont is the accumulator
906        // >
907        //    <let match_cont = if there's a guard {
908        //      let foo = %1, bar = %2, <...other bindings> in if <guard> then match_cont else fail_cont
909        //    } else {
910        //      let foo = %1, bar = %2, <...other bindings> in match_cont
911        //    }
912        //    >
913        //    <pattern.compile_part(value_id, match_cont, fail_cont)>
914        let fold_block = self
915            .branches
916            .into_iter()
917            .rev()
918            .fold(error_case, |fail_cont, branch| {
919                let bindings = branch
920                    .pattern
921                    .bindings()
922                    .iter()
923                    .map(|b| (b.id, LocIdent::fresh()))
924                    .collect::<HashMap<_, _>>();
925                let match_cont = if let Some(guard) = branch.guard {
926                    // The guard expression becomes part of the match
927                    // continuation, so it will be evaluated in the same
928                    // environment as the body of the branch.
929                    make::if_then_else(guard, branch.body, fail_cont.clone())
930                } else {
931                    branch.body
932                };
933
934                let match_cont = make::let_in(
935                    false,
936                    bindings
937                        .iter()
938                        .map(|(binding_id, fresh_id)| (*binding_id, Term::Var(*fresh_id))),
939                    match_cont,
940                );
941                branch
942                    .pattern
943                    .compile_part(pos_table, value_id, match_cont, fail_cont, &bindings)
944            });
945
946        // let value_id = value in <fold_block>
947        make::let_one_in(value_id, value, fold_block)
948    }
949}
950
951/// Simple wrapper used to implement specialization of match statements when all of the patterns
952/// are enum tags. Instead of a sequence of conditionals (which has linear time complexity), we use
953/// a special primops based on a hashmap, which has amortized constant time complexity.
954struct TagsOnlyMatch {
955    branches: Vec<(LocIdent, NickelValue)>,
956    default: Option<NickelValue>,
957}
958
959impl Compile for TagsOnlyMatch {
960    fn compile(self, _: &mut PosTable, value: NickelValue, pos_idx: PosIdx) -> NickelValue {
961        increment!("pattern_comile(tags_only_match)");
962
963        // We simply use the corresponding specialized primop in that case.
964        let match_op = mk_app!(
965            make::op1(
966                UnaryOp::TagsOnlyMatch {
967                    has_default: self.default.is_some()
968                },
969                value
970            )
971            .with_pos_idx(pos_idx),
972            NickelValue::record_posless(RecordData::with_field_values(self.branches.into_iter()))
973        );
974
975        let match_op = if let Some(default) = self.default {
976            mk_app!(match_op, default)
977        } else {
978            match_op
979        };
980
981        match_op.with_pos_idx(pos_idx)
982    }
983}
984
985/// Compile a destructuring let-binding into a `Term` (which has no destructuring let)
986///
987/// A let-binding
988///
989/// ```text
990/// let
991///   <pat1> = <bound1>,
992///   <pat2> = <bound2>
993/// in body
994/// ```
995///
996/// in which `<pat1>` binds `foo` and `<pat2>` binds `bar` is desugared to
997///
998/// ```text
999/// let
1000///   %b1 = <bound1>,
1001///   %b2 = <bound2>,
1002/// in
1003/// let
1004///   %r1 = <pat1.compile_part(%b1, { foo = %c1 }, <error>, <{ foo => %c1 }>)>,
1005///   %r2 = <pat2.compile_part(%b2, { baz = %c2 }, <error>), <{ bar => %c2 }>)>,
1006/// in
1007/// let
1008///   foo = %r1.foo,
1009///   ...
1010///   baz = %r2.baz,
1011/// in (%seq% %r1) (%seq% %r2) body
1012/// ```
1013/// where `foo` and `baz` are names bound in `<pat1>` and `<pat2>`.
1014///
1015/// There's some ambiguity about where to put the error-checking. It might be natural
1016/// to put it before trying to access `%r1.foo`, but that would only raise the error
1017/// if someone tries to evaluate `foo`. Putting it in the body as above raises
1018/// an error, for example, in `let 'Foo = 'Bar in true`.
1019///
1020/// A recursive let-binding is desugared almost the same way, except that everything is
1021/// shoved into a single let-rec block instead of three nested blocks.
1022pub fn compile_let_pattern<'ast>(
1023    pos_table: &mut PosTable,
1024    bindings: &[(&Pattern<'ast>, NickelValue)],
1025    body: NickelValue,
1026    attrs: LetAttrs,
1027) -> Term {
1028    // Outer bindings are the ones we called %b1 and %b2, and %empty_record_id in the doc above.
1029    let mut outer_bindings = SmallVec::new();
1030    // Mid bindings are the ones we called %r1 and %r2 above.
1031    let mut mid_bindings = SmallVec::new();
1032    // Inner bindings are the ones that bind the actual variables defined in the patterns.
1033    let mut inner_bindings = SmallVec::new();
1034
1035    for (pat, rhs) in bindings {
1036        let fused_pos = pat.pos.fuse(pos_table.get(rhs.pos_idx()));
1037        let error_case = NickelValue::term(
1038            Term::RuntimeError(Box::new(EvalErrorKind::FailedDestructuring {
1039                value: rhs.clone(),
1040                pattern_pos: pos_table.push(pat.pos),
1041            })),
1042            pos_table.push(fused_pos),
1043        );
1044
1045        let outer_id = LocIdent::fresh();
1046        outer_bindings.push((outer_id, rhs.clone()));
1047
1048        let mid_id = LocIdent::fresh();
1049
1050        let mut pattern_bindings = HashMap::new();
1051        for binding in pat.bindings() {
1052            let id = binding.id;
1053            inner_bindings.push((
1054                id,
1055                make::static_access(Term::Var(mid_id), std::iter::once(id)),
1056            ));
1057            pattern_bindings.insert(id, LocIdent::fresh());
1058        }
1059
1060        // Build the mapping from pattern variables to fresh variables
1061        // (see `CompilePart::compile_part` for details). This corresponds
1062        // to the <{ foo => %c1 }> part of the documentation above.
1063        let bindings_record = NickelValue::record_posless(RecordData {
1064            fields: pattern_bindings
1065                .iter()
1066                .map(|(bound_id, fresh_id)| {
1067                    (*bound_id, NickelValue::from(Term::Var(*fresh_id)).into())
1068                })
1069                .collect(),
1070            attrs: Default::default(),
1071            sealed_tail: None,
1072        });
1073
1074        mid_bindings.push((
1075            mid_id,
1076            pat.compile_part(
1077                pos_table,
1078                outer_id,
1079                bindings_record,
1080                error_case,
1081                &pattern_bindings,
1082            ),
1083        ));
1084    }
1085
1086    // Force all the "mid" ids, to make pattern failures more eager.
1087    // Without this, `let 'Foo = 'Bar in 1` wouldn't fail.
1088    let checked_body = mid_bindings.iter().rev().fold(body, |acc, (id, _)| {
1089        mk_app!(make::op1(UnaryOp::Seq, Term::Var(*id)), acc)
1090    });
1091
1092    let attrs = LetAttrs {
1093        binding_type: BindingType::Normal,
1094        rec: attrs.rec,
1095    };
1096    if attrs.rec {
1097        Term::let_in(
1098            outer_bindings
1099                .into_iter()
1100                .chain(mid_bindings)
1101                .chain(inner_bindings)
1102                .collect(),
1103            checked_body,
1104            attrs,
1105        )
1106    } else {
1107        Term::let_in(
1108            outer_bindings,
1109            Term::let_in(
1110                mid_bindings,
1111                Term::let_in(inner_bindings, checked_body, attrs.clone()).into(),
1112                attrs.clone(),
1113            )
1114            .into(),
1115            attrs,
1116        )
1117    }
1118}
1119
1120/// Compile a destructuring function into a `Term` (which has no destructuring function).
1121///
1122/// A function `fun <pat> => body` is desugared to `fun x => let <pat> = x in body`, and then we compile
1123/// the destructuring let.
1124pub fn compile_fun_pattern<'ast>(
1125    pos_table: &mut PosTable,
1126    pattern: &Pattern<'ast>,
1127    body: NickelValue,
1128) -> Term {
1129    let id = pattern.alias.unwrap_or_else(LocIdent::fresh);
1130    let wrapped_body = compile_let_pattern(
1131        pos_table,
1132        &[(pattern, Term::Var(id).into())],
1133        body,
1134        LetAttrs::default(),
1135    );
1136    Term::Fun(FunData {
1137        arg: id,
1138        body: wrapped_body.into(),
1139    })
1140}