Skip to main content

cranelift_isle/
codegen.rs

1//! Generate Rust code from a series of Sequences.
2
3use crate::files::Files;
4use crate::sema::{
5    BuiltinType, ExternalSig, ReturnKind, Term, TermEnv, TermId, Type, TypeEnv, TypeId,
6};
7use crate::serialize::{Block, ControlFlow, EvalStep, MatchArm};
8use crate::stablemapset::StableSet;
9use crate::trie_again::{Binding, BindingId, Constraint, RuleSet};
10use std::fmt::Write;
11use std::slice::Iter;
12use std::sync::Arc;
13
/// Fallback limit on the weight of a match arm's body, used when
/// `CodegenOptions::match_arm_split_threshold` is `None`. With arm splitting
/// enabled, an iterator-term arm heavier than this is wrapped in a local closure.
const DEFAULT_MATCH_ARM_BODY_CLOSURE_THRESHOLD: usize = 256;
15
16/// Options for code generation.
17#[derive(Clone, Debug, Default)]
18pub struct CodegenOptions {
19    /// Do not include the `#![allow(...)]` pragmas in the generated
20    /// source. Useful if it must be include!()'d elsewhere.
21    pub exclude_global_allow_pragmas: bool,
22
23    /// Prefixes to remove when printing file names in generaed files. This
24    /// helps keep codegen deterministic.
25    pub prefixes: Vec<Prefix>,
26
27    /// Emit `log::debug!` and `log::trace!` invocations in the generated code to help
28    /// debug rule matching and execution.
29    ///
30    /// In Cranelift this is typically controlled by a cargo feature on the
31    /// crate that includes the generated code (e.g. `cranelift-codegen`).
32    pub emit_logging: bool,
33
34    /// Split large match arms into local closures when generating iterator terms.
35    ///
36    /// In Cranelift this is typically controlled by a cargo feature on the
37    /// crate that includes the generated code (e.g. `cranelift-codegen`).
38    pub split_match_arms: bool,
39
40    /// Threshold for splitting match arms into local closures.
41    ///
42    /// If `None`, a default threshold is used.
43    pub match_arm_split_threshold: Option<usize>,
44}
45
/// A leading path fragment that is swapped for another name when file names
/// are printed.
#[derive(Clone, Debug)]
pub struct Prefix {
    /// The leading text to strip.
    pub prefix: String,

    /// The text printed in place of the stripped prefix.
    pub name: String,
}
55
56/// Emit Rust source code for the given type and term environments.
57pub fn codegen(
58    files: Arc<Files>,
59    typeenv: &TypeEnv,
60    termenv: &TermEnv,
61    terms: &[(TermId, RuleSet)],
62    options: &CodegenOptions,
63) -> String {
64    Codegen::compile(files, typeenv, termenv, terms).generate_rust(options)
65}
66
67#[derive(Clone, Debug)]
68struct Codegen<'a> {
69    files: Arc<Files>,
70    typeenv: &'a TypeEnv,
71    termenv: &'a TermEnv,
72    terms: &'a [(TermId, RuleSet)],
73}
74
75enum Nested<'a> {
76    Cases(Iter<'a, EvalStep>),
77    Arms(BindingId, Iter<'a, MatchArm>),
78}
79
80struct BodyContext<'a, W> {
81    out: &'a mut W,
82    ruleset: &'a RuleSet,
83    indent: String,
84    is_ref: StableSet<BindingId>,
85    is_bound: StableSet<BindingId>,
86    term_name: &'a str,
87    emit_logging: bool,
88    split_match_arms: bool,
89    match_arm_split_threshold: Option<usize>,
90
91    // Extra fields for iterator-returning terms.
92    // These fields are used to generate optimized Rust code for iterator-returning terms.
93    /// The number of match splits that have been generated.
94    /// This is used to generate unique names for the match splits.
95    match_split: usize,
96
97    /// The action to take when the iterator overflows.
98    iter_overflow_action: &'static str,
99}
100
101impl<'a, W: Write> BodyContext<'a, W> {
102    fn new(
103        out: &'a mut W,
104        ruleset: &'a RuleSet,
105        term_name: &'a str,
106        emit_logging: bool,
107        split_match_arms: bool,
108        match_arm_split_threshold: Option<usize>,
109        iter_overflow_action: &'static str,
110    ) -> Self {
111        Self {
112            out,
113            ruleset,
114            indent: Default::default(),
115            is_ref: Default::default(),
116            is_bound: Default::default(),
117            term_name,
118            emit_logging,
119            split_match_arms,
120            match_arm_split_threshold,
121            match_split: Default::default(),
122            iter_overflow_action,
123        }
124    }
125
126    fn enter_scope(&mut self) -> StableSet<BindingId> {
127        let new = self.is_bound.clone();
128        std::mem::replace(&mut self.is_bound, new)
129    }
130
131    fn begin_block(&mut self) -> std::fmt::Result {
132        self.indent.push_str("    ");
133        writeln!(self.out, " {{")
134    }
135
136    fn end_block(&mut self, last_line: &str, scope: StableSet<BindingId>) -> std::fmt::Result {
137        if !last_line.is_empty() {
138            writeln!(self.out, "{}{}", &self.indent, last_line)?;
139        }
140        self.is_bound = scope;
141        self.end_block_without_newline()?;
142        writeln!(self.out)
143    }
144
145    fn end_block_without_newline(&mut self) -> std::fmt::Result {
146        self.indent.truncate(self.indent.len() - 4);
147        write!(self.out, "{}}}", &self.indent)
148    }
149
150    fn set_ref(&mut self, binding: BindingId, is_ref: bool) {
151        if is_ref {
152            self.is_ref.insert(binding);
153        } else {
154            debug_assert!(!self.is_ref.contains(&binding));
155        }
156    }
157}
158
159impl<'a> Codegen<'a> {
160    fn compile(
161        files: Arc<Files>,
162        typeenv: &'a TypeEnv,
163        termenv: &'a TermEnv,
164        terms: &'a [(TermId, RuleSet)],
165    ) -> Codegen<'a> {
166        Codegen {
167            files,
168            typeenv,
169            termenv,
170            terms,
171        }
172    }
173
174    fn generate_rust(&self, options: &CodegenOptions) -> String {
175        let mut code = String::new();
176
177        self.generate_header(&mut code, options);
178        self.generate_ctx_trait(&mut code);
179        self.generate_internal_types(&mut code);
180        self.generate_internal_term_constructors(&mut code, options)
181            .unwrap();
182
183        code
184    }
185
186    fn generate_header(&self, code: &mut String, options: &CodegenOptions) {
187        writeln!(code, "// GENERATED BY ISLE. DO NOT EDIT!").unwrap();
188        writeln!(code, "//").unwrap();
189        writeln!(
190            code,
191            "// Generated automatically from the instruction-selection DSL code in:",
192        )
193        .unwrap();
194        for file in &self.files.file_names {
195            writeln!(code, "// - {file}").unwrap();
196        }
197
198        if !options.exclude_global_allow_pragmas {
199            writeln!(
200                code,
201                "\n#![allow(dead_code, unreachable_code, unreachable_patterns)]"
202            )
203            .unwrap();
204            writeln!(
205                code,
206                "#![allow(unused_imports, unused_variables, non_snake_case, unused_mut)]"
207            )
208            .unwrap();
209            writeln!(
210                code,
211                "#![allow(irrefutable_let_patterns, unused_assignments, non_camel_case_types)]"
212            )
213            .unwrap();
214        }
215
216        writeln!(code, "\nuse super::*;  // Pulls in all external types.").unwrap();
217        writeln!(code, "use core::marker::PhantomData;").unwrap();
218    }
219
220    fn generate_trait_sig(&self, code: &mut String, indent: &str, sig: &ExternalSig) {
221        let ret_tuple = format!(
222            "{open_paren}{rets}{close_paren}",
223            open_paren = if sig.ret_tys.len() != 1 { "(" } else { "" },
224            rets = sig
225                .ret_tys
226                .iter()
227                .map(|&ty| self.type_name(ty, /* by_ref = */ false))
228                .collect::<Vec<_>>()
229                .join(", "),
230            close_paren = if sig.ret_tys.len() != 1 { ")" } else { "" },
231        );
232
233        if sig.ret_kind == ReturnKind::Iterator {
234            writeln!(
235                code,
236                "{indent}type {name}_returns: Default + IntoContextIter<Context = Self, Output = {output}>;",
237                indent = indent,
238                name = sig.func_name,
239                output = ret_tuple,
240            )
241            .unwrap();
242        }
243
244        let ret_ty = match sig.ret_kind {
245            ReturnKind::Plain => ret_tuple,
246            ReturnKind::Option => format!("Option<{ret_tuple}>"),
247            ReturnKind::Iterator => format!("()"),
248        };
249
250        writeln!(
251            code,
252            "{indent}fn {name}(&mut self, {params}) -> {ret_ty};",
253            indent = indent,
254            name = sig.func_name,
255            params = sig
256                .param_tys
257                .iter()
258                .enumerate()
259                .map(|(i, &ty)| format!("arg{}: {}", i, self.type_name(ty, /* by_ref = */ true)))
260                .chain(if sig.ret_kind == ReturnKind::Iterator {
261                    Some(format!("returns: &mut Self::{}_returns", sig.func_name))
262                } else {
263                    None
264                })
265                .collect::<Vec<_>>()
266                .join(", "),
267            ret_ty = ret_ty,
268        )
269        .unwrap();
270    }
271
272    fn generate_ctx_trait(&self, code: &mut String) {
273        writeln!(code).unwrap();
274        writeln!(
275            code,
276            "/// Context during lowering: an implementation of this trait"
277        )
278        .unwrap();
279        writeln!(
280            code,
281            "/// must be provided with all external constructors and extractors."
282        )
283        .unwrap();
284        writeln!(
285            code,
286            "/// A mutable borrow is passed along through all lowering logic."
287        )
288        .unwrap();
289        writeln!(code, "pub trait Context {{").unwrap();
290        for term in &self.termenv.terms {
291            if term.has_external_extractor() {
292                let ext_sig = term.extractor_sig(self.typeenv).unwrap();
293                self.generate_trait_sig(code, "    ", &ext_sig);
294            }
295            if term.has_external_constructor() {
296                let ext_sig = term.constructor_sig(self.typeenv).unwrap();
297                self.generate_trait_sig(code, "    ", &ext_sig);
298            }
299        }
300        writeln!(code, "}}").unwrap();
301        writeln!(
302            code,
303            r#"
304pub trait ContextIter {{
305    type Context;
306    type Output;
307    fn next(&mut self, ctx: &mut Self::Context) -> Option<Self::Output>;
308    fn size_hint(&self) -> (usize, Option<usize>) {{ (0, None) }}
309}}
310
311pub trait IntoContextIter {{
312    type Context;
313    type Output;
314    type IntoIter: ContextIter<Context = Self::Context, Output = Self::Output>;
315    fn into_context_iter(self) -> Self::IntoIter;
316}}
317
318pub trait Length {{
319    fn len(&self) -> usize;
320}}
321
322impl<T> Length for alloc::vec::Vec<T> {{
323    fn len(&self) -> usize {{
324        alloc::vec::Vec::len(self)
325    }}
326}}
327
328pub struct ContextIterWrapper<I, C> {{
329    iter: I,
330    _ctx: core::marker::PhantomData<C>,
331}}
332impl<I: Default, C> Default for ContextIterWrapper<I, C> {{
333    fn default() -> Self {{
334        ContextIterWrapper {{
335            iter: I::default(),
336            _ctx: core::marker::PhantomData
337        }}
338    }}
339}}
340impl<I, C> core::ops::Deref for ContextIterWrapper<I, C> {{
341    type Target = I;
342    fn deref(&self) -> &I {{
343        &self.iter
344    }}
345}}
346impl<I, C> core::ops::DerefMut for ContextIterWrapper<I, C> {{
347    fn deref_mut(&mut self) -> &mut I {{
348        &mut self.iter
349    }}
350}}
351impl<I: Iterator, C: Context> From<I> for ContextIterWrapper<I, C> {{
352    fn from(iter: I) -> Self {{
353        Self {{ iter, _ctx: core::marker::PhantomData }}
354    }}
355}}
356impl<I: Iterator, C: Context> ContextIter for ContextIterWrapper<I, C> {{
357    type Context = C;
358    type Output = I::Item;
359    fn next(&mut self, _ctx: &mut Self::Context) -> Option<Self::Output> {{
360        self.iter.next()
361    }}
362    fn size_hint(&self) -> (usize, Option<usize>) {{
363        self.iter.size_hint()
364    }}
365}}
366impl<I: IntoIterator, C: Context> IntoContextIter for ContextIterWrapper<I, C> {{
367    type Context = C;
368    type Output = I::Item;
369    type IntoIter = ContextIterWrapper<I::IntoIter, C>;
370    fn into_context_iter(self) -> Self::IntoIter {{
371        ContextIterWrapper {{
372            iter: self.iter.into_iter(),
373            _ctx: core::marker::PhantomData
374        }}
375    }}
376}}
377impl<T, E: Extend<T>, C> Extend<T> for ContextIterWrapper<E, C> {{
378    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {{
379        self.iter.extend(iter);
380    }}
381}}
382impl<L: Length, C> Length for ContextIterWrapper<L, C> {{
383    fn len(&self) -> usize {{
384        self.iter.len()
385    }}
386}}
387           "#,
388        )
389        .unwrap();
390    }
391
392    fn generate_internal_types(&self, code: &mut String) {
393        for ty in &self.typeenv.types {
394            match ty {
395                &Type::Enum {
396                    name,
397                    is_extern,
398                    is_nodebug,
399                    ref variants,
400                    pos,
401                    ..
402                } if !is_extern => {
403                    let name = &self.typeenv.syms[name.index()];
404                    writeln!(
405                        code,
406                        "\n/// Internal type {}: defined at {}.",
407                        name,
408                        pos.pretty_print_line(&self.files)
409                    )
410                    .unwrap();
411
412                    // Generate the `derive`s.
413                    let debug_derive = if is_nodebug { "" } else { ", Debug" };
414                    if variants.iter().all(|v| v.fields.is_empty()) {
415                        writeln!(code, "#[derive(Copy, Clone, PartialEq, Eq{debug_derive})]")
416                            .unwrap();
417                    } else {
418                        writeln!(code, "#[derive(Clone{debug_derive})]").unwrap();
419                    }
420
421                    writeln!(code, "pub enum {name} {{").unwrap();
422                    for variant in variants {
423                        let name = &self.typeenv.syms[variant.name.index()];
424                        if variant.fields.is_empty() {
425                            writeln!(code, "    {name},").unwrap();
426                        } else {
427                            writeln!(code, "    {name} {{").unwrap();
428                            for field in &variant.fields {
429                                let name = &self.typeenv.syms[field.name.index()];
430                                let ty_name =
431                                    self.typeenv.types[field.ty.index()].name(self.typeenv);
432                                writeln!(code, "        {name}: {ty_name},").unwrap();
433                            }
434                            writeln!(code, "    }},").unwrap();
435                        }
436                    }
437                    writeln!(code, "}}").unwrap();
438                }
439                _ => {}
440            }
441        }
442    }
443
444    fn type_name(&self, typeid: TypeId, by_ref: bool) -> String {
445        match self.typeenv.types[typeid.index()] {
446            Type::Builtin(bt) => String::from(bt.name()),
447            Type::Primitive(_, sym, _) => self.typeenv.syms[sym.index()].clone(),
448            Type::Enum { name, .. } => {
449                let r = if by_ref { "&" } else { "" };
450                format!("{}{}", r, self.typeenv.syms[name.index()])
451            }
452        }
453    }
454
455    fn generate_internal_term_constructors(
456        &self,
457        code: &mut String,
458        options: &CodegenOptions,
459    ) -> std::fmt::Result {
460        for &(termid, ref ruleset) in self.terms.iter() {
461            let root = crate::serialize::serialize(ruleset);
462
463            let termdata = &self.termenv.terms[termid.index()];
464            let term_name = &self.typeenv.syms[termdata.name.index()];
465
466            // Split a match if the term returns an iterator.
467            let mut ctx = BodyContext::new(
468                code,
469                ruleset,
470                term_name,
471                options.emit_logging,
472                options.split_match_arms,
473                options.match_arm_split_threshold,
474                "return;", // At top level, we just return.
475            );
476
477            // Generate the function signature.
478            writeln!(ctx.out)?;
479            writeln!(
480                ctx.out,
481                "{}// Generated as internal constructor for term {}.",
482                &ctx.indent, term_name,
483            )?;
484
485            let sig = termdata.constructor_sig(self.typeenv).unwrap();
486            writeln!(
487                ctx.out,
488                "{}pub fn {}<C: Context>(",
489                &ctx.indent, sig.func_name
490            )?;
491
492            writeln!(ctx.out, "{}    ctx: &mut C,", &ctx.indent)?;
493            for (i, &ty) in sig.param_tys.iter().enumerate() {
494                let (is_ref, ty) = self.ty(ty);
495                write!(ctx.out, "{}    arg{}: ", &ctx.indent, i)?;
496                write!(ctx.out, "{}{}", if is_ref { "&" } else { "" }, ty)?;
497                if let Some(binding) = ctx.ruleset.find_binding(&Binding::Argument {
498                    index: i.try_into().unwrap(),
499                }) {
500                    ctx.set_ref(binding, is_ref);
501                }
502                writeln!(ctx.out, ",")?;
503            }
504
505            let (_, ret) = self.ty(sig.ret_tys[0]);
506
507            if let ReturnKind::Iterator = sig.ret_kind {
508                writeln!(
509                    ctx.out,
510                    "{}    returns: &mut (impl Extend<{}> + Length),",
511                    &ctx.indent, ret
512                )?;
513            }
514
515            write!(ctx.out, "{}) -> ", &ctx.indent)?;
516            match sig.ret_kind {
517                ReturnKind::Iterator => write!(ctx.out, "()")?,
518                ReturnKind::Option => write!(ctx.out, "Option<{ret}>")?,
519                ReturnKind::Plain => write!(ctx.out, "{ret}")?,
520            };
521            // Generating the function signature is done.
522
523            let last_expr = if let Some(EvalStep {
524                check: ControlFlow::Return { .. },
525                ..
526            }) = root.steps.last()
527            {
528                // If there's an outermost fallback, no need for another `return` statement.
529                String::new()
530            } else {
531                match sig.ret_kind {
532                    ReturnKind::Iterator => String::new(),
533                    ReturnKind::Option => "None".to_string(),
534                    ReturnKind::Plain => format!(
535                        "unreachable!(\"no rule matched for term {{}} at {{}}; should it be partial?\", {:?}, {:?})",
536                        term_name,
537                        termdata.decl_pos.pretty_print_line(&self.files)
538                    ),
539                }
540            };
541
542            let scope = ctx.enter_scope();
543            self.emit_block(&mut ctx, &root, sig.ret_kind, &last_expr, scope)?;
544        }
545        Ok(())
546    }
547
548    fn ty(&self, typeid: TypeId) -> (bool, String) {
549        let ty = &self.typeenv.types[typeid.index()];
550        let name = ty.name(self.typeenv);
551        let is_ref = match ty {
552            Type::Builtin(_) | Type::Primitive(..) => false,
553            Type::Enum { .. } => true,
554        };
555        (is_ref, String::from(name))
556    }
557
558    fn validate_block(ret_kind: ReturnKind, block: &Block) -> Nested<'_> {
559        if !matches!(ret_kind, ReturnKind::Iterator) {
560            // Loops are only allowed if we're returning an iterator.
561            assert!(
562                !block
563                    .steps
564                    .iter()
565                    .any(|c| matches!(c.check, ControlFlow::Loop { .. }))
566            );
567
568            // Unless we're returning an iterator, a case which returns a result must be the last
569            // case in a block.
570            if let Some(result_pos) = block
571                .steps
572                .iter()
573                .position(|c| matches!(c.check, ControlFlow::Return { .. }))
574            {
575                assert_eq!(block.steps.len() - 1, result_pos);
576            }
577        }
578
579        Nested::Cases(block.steps.iter())
580    }
581
582    fn block_weight(block: &Block) -> usize {
583        fn cf_weight(cf: &ControlFlow) -> usize {
584            match cf {
585                ControlFlow::Match { arms, .. } => {
586                    arms.iter().map(|a| Codegen::block_weight(&a.body)).sum()
587                }
588                ControlFlow::Equal { body, .. } => Codegen::block_weight(body),
589                ControlFlow::Loop { body, .. } => Codegen::block_weight(body),
590                ControlFlow::Return { .. } => 0,
591            }
592        }
593
594        block.steps.iter().map(|s| 1 + cf_weight(&s.check)).sum()
595    }
596
597    fn emit_block<W: Write>(
598        &self,
599        ctx: &mut BodyContext<W>,
600        block: &Block,
601        ret_kind: ReturnKind,
602        last_expr: &str,
603        scope: StableSet<BindingId>,
604    ) -> std::fmt::Result {
605        ctx.begin_block()?;
606        self.emit_block_contents(ctx, block, ret_kind, last_expr, scope)
607    }
608
609    fn emit_block_contents<W: Write>(
610        &self,
611        ctx: &mut BodyContext<W>,
612        block: &Block,
613        ret_kind: ReturnKind,
614        last_expr: &str,
615        scope: StableSet<BindingId>,
616    ) -> std::fmt::Result {
617        let mut stack = Vec::new();
618        stack.push((Self::validate_block(ret_kind, block), last_expr, scope));
619
620        while let Some((mut nested, last_line, scope)) = stack.pop() {
621            match &mut nested {
622                Nested::Cases(cases) => {
623                    let Some(case) = cases.next() else {
624                        ctx.end_block(last_line, scope)?;
625                        continue;
626                    };
627                    // Iterator isn't done, put it back on the stack.
628                    stack.push((nested, last_line, scope));
629
630                    for &expr in case.bind_order.iter() {
631                        let iter_return = match &ctx.ruleset.bindings[expr.index()] {
632                            Binding::Extractor { term, .. } => {
633                                let termdata = &self.termenv.terms[term.index()];
634                                let sig = termdata.extractor_sig(self.typeenv).unwrap();
635                                if sig.ret_kind == ReturnKind::Iterator {
636                                    if termdata.has_external_extractor() {
637                                        Some(format!("C::{}_returns", sig.func_name))
638                                    } else {
639                                        Some(format!("ContextIterWrapper::<ConstructorVec<_>, _>"))
640                                    }
641                                } else {
642                                    None
643                                }
644                            }
645                            Binding::Constructor { term, .. } => {
646                                let termdata = &self.termenv.terms[term.index()];
647                                let sig = termdata.constructor_sig(self.typeenv).unwrap();
648                                if sig.ret_kind == ReturnKind::Iterator {
649                                    if termdata.has_external_constructor() {
650                                        Some(format!("C::{}_returns", sig.func_name))
651                                    } else {
652                                        Some(format!("ContextIterWrapper::<ConstructorVec<_>, _>"))
653                                    }
654                                } else {
655                                    None
656                                }
657                            }
658                            _ => None,
659                        };
660                        if let Some(ty) = iter_return {
661                            writeln!(
662                                ctx.out,
663                                "{}let mut v{} = {}::default();",
664                                &ctx.indent,
665                                expr.index(),
666                                ty
667                            )?;
668                            write!(ctx.out, "{}", &ctx.indent)?;
669                        } else {
670                            write!(ctx.out, "{}let v{} = ", &ctx.indent, expr.index())?;
671                        }
672                        self.emit_expr(ctx, expr)?;
673                        writeln!(ctx.out, ";")?;
674                        ctx.is_bound.insert(expr);
675                    }
676
677                    match &case.check {
678                        // Use a shorthand notation if there's only one match arm.
679                        ControlFlow::Match { source, arms } if arms.len() == 1 => {
680                            let arm = &arms[0];
681                            let scope = ctx.enter_scope();
682                            match arm.constraint {
683                                Constraint::ConstBool { .. }
684                                | Constraint::ConstInt { .. }
685                                | Constraint::ConstPrim { .. } => {
686                                    write!(ctx.out, "{}if ", &ctx.indent)?;
687                                    self.emit_expr(ctx, *source)?;
688                                    write!(ctx.out, " == ")?;
689                                    self.emit_constraint(ctx, *source, arm)?;
690                                }
691                                Constraint::Variant { .. } | Constraint::Some => {
692                                    write!(ctx.out, "{}if let ", &ctx.indent)?;
693                                    self.emit_constraint(ctx, *source, arm)?;
694                                    write!(ctx.out, " = ")?;
695                                    self.emit_source(ctx, *source, arm.constraint)?;
696                                }
697                            }
698                            ctx.begin_block()?;
699                            stack.push((Self::validate_block(ret_kind, &arm.body), "", scope));
700                        }
701
702                        ControlFlow::Match { source, arms } => {
703                            let scope = ctx.enter_scope();
704                            write!(ctx.out, "{}match ", &ctx.indent)?;
705                            self.emit_source(ctx, *source, arms[0].constraint)?;
706                            ctx.begin_block()?;
707
708                            // Always add a catchall arm, because we
709                            // don't do exhaustiveness checking on the
710                            // match arms.
711                            stack.push((Nested::Arms(*source, arms.iter()), "_ => {}", scope));
712                        }
713
714                        ControlFlow::Equal { a, b, body } => {
715                            let scope = ctx.enter_scope();
716                            write!(ctx.out, "{}if ", &ctx.indent)?;
717                            self.emit_expr(ctx, *a)?;
718                            write!(ctx.out, " == ")?;
719                            self.emit_expr(ctx, *b)?;
720                            ctx.begin_block()?;
721                            stack.push((Self::validate_block(ret_kind, body), "", scope));
722                        }
723
724                        ControlFlow::Loop { result, body } => {
725                            let source = match &ctx.ruleset.bindings[result.index()] {
726                                Binding::Iterator { source } => source,
727                                _ => unreachable!("Loop from a non-Iterator"),
728                            };
729                            let scope = ctx.enter_scope();
730
731                            writeln!(
732                                ctx.out,
733                                "{}let mut v{} = v{}.into_context_iter();",
734                                &ctx.indent,
735                                source.index(),
736                                source.index(),
737                            )?;
738
739                            write!(
740                                ctx.out,
741                                "{}while let Some(v{}) = v{}.next(ctx)",
742                                &ctx.indent,
743                                result.index(),
744                                source.index()
745                            )?;
746                            ctx.is_bound.insert(*result);
747                            ctx.begin_block()?;
748                            stack.push((Self::validate_block(ret_kind, body), "", scope));
749                        }
750
751                        &ControlFlow::Return { pos, result } => {
752                            writeln!(
753                                ctx.out,
754                                "{}// Rule at {}.",
755                                &ctx.indent,
756                                pos.pretty_print_line(&self.files)
757                            )?;
758                            if ctx.emit_logging {
759                                // Produce a valid Rust string literal with escapes.
760                                let pp = pos.pretty_print_line(&self.files);
761                                writeln!(
762                                    ctx.out,
763                                    "{}log::debug!(\"ISLE {{}} {{}}\", {:?}, {:?});",
764                                    &ctx.indent, ctx.term_name, pp
765                                )?;
766                            }
767                            write!(ctx.out, "{}", &ctx.indent)?;
768                            match ret_kind {
769                                ReturnKind::Plain | ReturnKind::Option => {
770                                    write!(ctx.out, "return ")?
771                                }
772                                ReturnKind::Iterator => write!(ctx.out, "returns.extend(Some(")?,
773                            }
774                            self.emit_expr(ctx, result)?;
775                            if ctx.is_ref.contains(&result) {
776                                write!(ctx.out, ".clone()")?;
777                            }
778                            match ret_kind {
779                                ReturnKind::Plain | ReturnKind::Option => writeln!(ctx.out, ";")?,
780                                ReturnKind::Iterator => {
781                                    writeln!(ctx.out, "));")?;
782                                    writeln!(
783                                        ctx.out,
784                                        "{}if returns.len() >= MAX_ISLE_RETURNS {{ {} }}",
785                                        ctx.indent, ctx.iter_overflow_action
786                                    )?;
787                                }
788                            }
789                        }
790                    }
791                }
792
793                Nested::Arms(source, arms) => {
794                    let Some(arm) = arms.next() else {
795                        ctx.end_block(last_line, scope)?;
796                        continue;
797                    };
798                    let source = *source;
799                    // Iterator isn't done, put it back on the stack.
800                    stack.push((nested, last_line, scope));
801
802                    let scope = ctx.enter_scope();
803                    write!(ctx.out, "{}", &ctx.indent)?;
804                    self.emit_constraint(ctx, source, arm)?;
805                    write!(ctx.out, " =>")?;
806                    ctx.begin_block()?;
807
808                    // Compile-time optimization: huge function bodies (often from very large match arms
809                    // of constructor bodies) cause rustc to spend a lot of time in analysis passes.
810                    // Wrap such bodies in a local closure to move the bulk of the work into a separate body
811                    // without needing to know the types of captured locals.
812                    let match_arm_body_closure_threshold = ctx
813                        .match_arm_split_threshold
814                        .unwrap_or(DEFAULT_MATCH_ARM_BODY_CLOSURE_THRESHOLD);
815                    if ctx.split_match_arms
816                        && ret_kind == ReturnKind::Iterator
817                        && Codegen::block_weight(&arm.body) > match_arm_body_closure_threshold
818                    {
819                        let closure_id = ctx.match_split;
820                        ctx.match_split += 1;
821
822                        write!(ctx.out, "{}if (|| -> bool", &ctx.indent)?;
823                        ctx.begin_block()?;
824
825                        let old_overflow_action = ctx.iter_overflow_action;
826                        ctx.iter_overflow_action = "return true;";
827                        let closure_scope = ctx.enter_scope();
828                        self.emit_block_contents(ctx, &arm.body, ret_kind, "false", closure_scope)?;
829                        ctx.iter_overflow_action = old_overflow_action;
830
831                        // Close `if (|| -> bool { ... })()` and stop the outer function on
832                        // iterator-overflow.
833                        writeln!(
834                            ctx.out,
835                            "{})() {{ {} }} // __isle_arm_{}",
836                            &ctx.indent, ctx.iter_overflow_action, closure_id
837                        )?;
838
839                        ctx.end_block("", scope)?;
840                    } else {
841                        stack.push((Self::validate_block(ret_kind, &arm.body), "", scope));
842                    }
843                }
844            }
845        }
846
847        Ok(())
848    }
849
850    fn emit_expr<W: Write>(&self, ctx: &mut BodyContext<W>, result: BindingId) -> std::fmt::Result {
851        if ctx.is_bound.contains(&result) {
852            return write!(ctx.out, "v{}", result.index());
853        }
854
855        let binding = &ctx.ruleset.bindings[result.index()];
856
857        let mut call =
858            |term: TermId,
859             parameters: &[BindingId],
860
861             get_sig: fn(&Term, &TypeEnv) -> Option<ExternalSig>| {
862                let termdata = &self.termenv.terms[term.index()];
863                let sig = get_sig(termdata, self.typeenv).unwrap();
864                if let &[ret_ty] = &sig.ret_tys[..] {
865                    let (is_ref, _) = self.ty(ret_ty);
866                    if is_ref {
867                        ctx.set_ref(result, true);
868                        write!(ctx.out, "&")?;
869                    }
870                }
871                write!(ctx.out, "{}(ctx", sig.full_name)?;
872                debug_assert_eq!(parameters.len(), sig.param_tys.len());
873                for (&parameter, &arg_ty) in parameters.iter().zip(sig.param_tys.iter()) {
874                    let (is_ref, _) = self.ty(arg_ty);
875                    write!(ctx.out, ", ")?;
876                    let (before, after) = match (is_ref, ctx.is_ref.contains(&parameter)) {
877                        (false, true) => ("", ".clone()"),
878                        (true, false) => ("&", ""),
879                        _ => ("", ""),
880                    };
881                    write!(ctx.out, "{before}")?;
882                    self.emit_expr(ctx, parameter)?;
883                    write!(ctx.out, "{after}")?;
884                }
885                if let ReturnKind::Iterator = sig.ret_kind {
886                    write!(ctx.out, ", &mut v{}", result.index())?;
887                }
888                write!(ctx.out, ")")
889            };
890
891        match binding {
892            &Binding::ConstBool { val, .. } => self.emit_bool(ctx, val),
893            &Binding::ConstInt { val, ty } => self.emit_int(ctx, val, ty),
894            Binding::ConstPrim { val } => write!(ctx.out, "{}", &self.typeenv.syms[val.index()]),
895            Binding::Argument { index } => write!(ctx.out, "arg{}", index.index()),
896            Binding::Extractor { term, parameter } => {
897                call(*term, std::slice::from_ref(parameter), Term::extractor_sig)
898            }
899            Binding::Constructor {
900                term, parameters, ..
901            } => call(*term, &parameters[..], Term::constructor_sig),
902
903            Binding::MakeVariant {
904                ty,
905                variant,
906                fields,
907            } => {
908                let (name, variants) = match &self.typeenv.types[ty.index()] {
909                    Type::Enum { name, variants, .. } => (name, variants),
910                    _ => unreachable!("MakeVariant with primitive type"),
911                };
912                let variant = &variants[variant.index()];
913                write!(
914                    ctx.out,
915                    "{}::{}",
916                    &self.typeenv.syms[name.index()],
917                    &self.typeenv.syms[variant.name.index()]
918                )?;
919                if !fields.is_empty() {
920                    ctx.begin_block()?;
921                    for (field, value) in variant.fields.iter().zip(fields.iter()) {
922                        write!(
923                            ctx.out,
924                            "{}{}: ",
925                            &ctx.indent,
926                            &self.typeenv.syms[field.name.index()],
927                        )?;
928                        self.emit_expr(ctx, *value)?;
929                        if ctx.is_ref.contains(value) {
930                            write!(ctx.out, ".clone()")?;
931                        }
932                        writeln!(ctx.out, ",")?;
933                    }
934                    ctx.end_block_without_newline()?;
935                }
936                Ok(())
937            }
938
939            &Binding::MakeSome { inner } => {
940                write!(ctx.out, "Some(")?;
941                self.emit_expr(ctx, inner)?;
942                write!(ctx.out, ")")
943            }
944            &Binding::MatchSome { source } => {
945                self.emit_expr(ctx, source)?;
946                write!(ctx.out, "?")
947            }
948            &Binding::MatchTuple { source, field } => {
949                self.emit_expr(ctx, source)?;
950                write!(ctx.out, ".{}", field.index())
951            }
952
953            // These are not supposed to happen. If they do, make the generated code fail to compile
954            // so this is easier to debug than if we panic during codegen.
955            &Binding::MatchVariant { source, field, .. } => {
956                self.emit_expr(ctx, source)?;
957                write!(ctx.out, ".{} /*FIXME*/", field.index())
958            }
959            &Binding::Iterator { source } => {
960                self.emit_expr(ctx, source)?;
961                write!(ctx.out, ".next() /*FIXME*/")
962            }
963        }
964    }
965
966    fn emit_source<W: Write>(
967        &self,
968        ctx: &mut BodyContext<W>,
969        source: BindingId,
970        constraint: Constraint,
971    ) -> std::fmt::Result {
972        if let Constraint::Variant { .. } = constraint {
973            if !ctx.is_ref.contains(&source) {
974                write!(ctx.out, "&")?;
975            }
976        }
977        self.emit_expr(ctx, source)
978    }
979
980    fn emit_constraint<W: Write>(
981        &self,
982        ctx: &mut BodyContext<W>,
983        source: BindingId,
984        arm: &MatchArm,
985    ) -> std::fmt::Result {
986        let MatchArm {
987            constraint,
988            bindings,
989            ..
990        } = arm;
991        for binding in bindings.iter() {
992            if let &Some(binding) = binding {
993                ctx.is_bound.insert(binding);
994            }
995        }
996        match *constraint {
997            Constraint::ConstBool { val, .. } => self.emit_bool(ctx, val),
998            Constraint::ConstInt { val, ty } => self.emit_int(ctx, val, ty),
999            Constraint::ConstPrim { val } => {
1000                write!(ctx.out, "{}", &self.typeenv.syms[val.index()])
1001            }
1002            Constraint::Variant { ty, variant, .. } => {
1003                let (name, variants) = match &self.typeenv.types[ty.index()] {
1004                    Type::Enum { name, variants, .. } => (name, variants),
1005                    _ => unreachable!("Variant constraint on primitive type"),
1006                };
1007                let variant = &variants[variant.index()];
1008                write!(
1009                    ctx.out,
1010                    "&{}::{}",
1011                    &self.typeenv.syms[name.index()],
1012                    &self.typeenv.syms[variant.name.index()]
1013                )?;
1014                if !bindings.is_empty() {
1015                    ctx.begin_block()?;
1016                    let mut skipped_some = false;
1017                    for (&binding, field) in bindings.iter().zip(variant.fields.iter()) {
1018                        if let Some(binding) = binding {
1019                            write!(
1020                                ctx.out,
1021                                "{}{}: ",
1022                                &ctx.indent,
1023                                &self.typeenv.syms[field.name.index()]
1024                            )?;
1025                            let (is_ref, _) = self.ty(field.ty);
1026                            if is_ref {
1027                                ctx.set_ref(binding, true);
1028                                write!(ctx.out, "ref ")?;
1029                            }
1030                            writeln!(ctx.out, "v{},", binding.index())?;
1031                        } else {
1032                            skipped_some = true;
1033                        }
1034                    }
1035                    if skipped_some {
1036                        writeln!(ctx.out, "{}..", &ctx.indent)?;
1037                    }
1038                    ctx.end_block_without_newline()?;
1039                }
1040                Ok(())
1041            }
1042            Constraint::Some => {
1043                write!(ctx.out, "Some(")?;
1044                if let Some(binding) = bindings[0] {
1045                    ctx.set_ref(binding, ctx.is_ref.contains(&source));
1046                    write!(ctx.out, "v{}", binding.index())?;
1047                } else {
1048                    write!(ctx.out, "_")?;
1049                }
1050                write!(ctx.out, ")")
1051            }
1052        }
1053    }
1054
1055    fn emit_bool<W: Write>(
1056        &self,
1057        ctx: &mut BodyContext<W>,
1058        val: bool,
1059    ) -> Result<(), std::fmt::Error> {
1060        write!(ctx.out, "{val}")
1061    }
1062
1063    fn emit_int<W: Write>(
1064        &self,
1065        ctx: &mut BodyContext<W>,
1066        val: i128,
1067        ty: TypeId,
1068    ) -> Result<(), std::fmt::Error> {
1069        let ty_data = &self.typeenv.types[ty.index()];
1070        match ty_data {
1071            Type::Builtin(BuiltinType::Int(ty)) if ty.is_signed() => write!(ctx.out, "{val}_{ty}"),
1072            Type::Builtin(BuiltinType::Int(ty)) => write!(ctx.out, "{val:#x}_{ty}"),
1073            _ => write!(ctx.out, "{val:#x}"),
1074        }
1075    }
1076}