// midenc_codegen_masm/masm/region.rs

1use std::{collections::BTreeSet, fmt};
2
3use cranelift_entity::PrimaryMap;
4use miden_assembly::ast;
5use midenc_hir::{diagnostics::Span, formatter::PrettyPrint, FunctionIdent, Ident};
6use smallvec::smallvec;
7
8use super::*;
9use crate::InstructionPointer;
10
11/// This struct represents a region of code in Miden Assembly.
12///
13/// A region is a tree of blocks with isolated scope. In many
14/// ways a [Region] is basically a [Function], just without any
15/// identity. Additionally, a [Region] does not have local variables,
16/// those must be provided by a parent [Function].
17///
18/// In short, this represents both the body of a function, and the
19/// body of a `begin` block in Miden Assembly.
20#[derive(Debug, Clone)]
21pub struct Region {
22    pub body: BlockId,
23    pub blocks: PrimaryMap<BlockId, Block>,
24}
25impl Default for Region {
26    fn default() -> Self {
27        let mut blocks = PrimaryMap::<BlockId, Block>::default();
28        let id = blocks.next_key();
29        let body = blocks.push(Block {
30            id,
31            ops: smallvec![],
32        });
33        Self { body, blocks }
34    }
35}
36impl Region {
37    /// Get the [BlockId] for the block which forms the body of this region
38    #[inline(always)]
39    pub const fn id(&self) -> BlockId {
40        self.body
41    }
42
43    /// Get a reference to a [Block] by [BlockId]
44    #[inline]
45    pub fn block(&self, id: BlockId) -> &Block {
46        &self.blocks[id]
47    }
48
49    /// Get a mutable reference to a [Block] by [BlockId]
50    #[inline]
51    pub fn block_mut(&mut self, id: BlockId) -> &mut Block {
52        &mut self.blocks[id]
53    }
54
55    /// Get the instruction under `ip`, if valid
56    pub fn get(&self, ip: InstructionPointer) -> Option<Span<Op>> {
57        self.blocks[ip.block].ops.get(ip.index).copied()
58    }
59
60    /// Allocate a new code block in this region
61    pub fn create_block(&mut self) -> BlockId {
62        let id = self.blocks.next_key();
63        self.blocks.push(Block {
64            id,
65            ops: smallvec![],
66        });
67        id
68    }
69
70    /// Render the code in this region as Miden Assembly, at the specified indentation level (in
71    /// units of 4 spaces)
72    pub fn display<'a, 'b: 'a>(
73        &'b self,
74        function: Option<FunctionIdent>,
75        imports: &'b ModuleImportInfo,
76    ) -> DisplayRegion<'a> {
77        DisplayRegion {
78            region: self,
79            function,
80            imports,
81        }
82    }
83
84    /// Convert this [Region] to a [miden_assembly::ast::Block] using the provided
85    /// local/external function maps to handle calls present in the body of the region.
86    pub fn to_block(
87        &self,
88        imports: &ModuleImportInfo,
89        locals: &BTreeSet<FunctionIdent>,
90    ) -> ast::Block {
91        emit_block(self.body, &self.blocks, imports, locals)
92    }
93
94    /// Create a [Region] from a [miden_assembly::ast::CodeBody] and the set of imports
95    /// and local procedures which will be used to map references to procedures to their
96    /// fully-qualified names.
97    pub fn from_block(current_module: Ident, code: &ast::Block) -> Self {
98        let mut region = Self::default();
99
100        let body = region.body;
101        import_block(current_module, &mut region, body, code);
102
103        region
104    }
105}
106impl core::ops::Index<InstructionPointer> for Region {
107    type Output = Op;
108
109    #[inline]
110    fn index(&self, ip: InstructionPointer) -> &Self::Output {
111        &self.blocks[ip.block].ops[ip.index]
112    }
113}
114
115#[doc(hidden)]
116pub struct DisplayRegion<'a> {
117    region: &'a Region,
118    function: Option<FunctionIdent>,
119    imports: &'a ModuleImportInfo,
120}
121impl<'a> midenc_hir::formatter::PrettyPrint for DisplayRegion<'a> {
122    fn render(&self) -> midenc_hir::formatter::Document {
123        use midenc_hir::DisplayMasmBlock;
124
125        let block = DisplayMasmBlock::new(
126            self.function,
127            Some(self.imports),
128            &self.region.blocks,
129            self.region.body,
130        );
131
132        block.render()
133    }
134}
135impl<'a> fmt::Display for DisplayRegion<'a> {
136    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
137        self.pretty_print(f)
138    }
139}
140
141/// Import code from a [miden_assembly::ast::Block] into the specified [Block] in `region`.
142fn import_block(
143    current_module: Ident,
144    region: &mut Region,
145    current_block_id: BlockId,
146    block: &ast::Block,
147) {
148    for op in block.iter() {
149        match op {
150            ast::Op::Inst(ix) => {
151                let span = ix.span();
152                let current_block = region.block_mut(current_block_id);
153                let ops = Op::from_masm(current_module, (**ix).clone());
154                current_block.extend(ops.into_iter().map(|op| Span::new(span, op)));
155            }
156            ast::Op::If {
157                span,
158                ref then_blk,
159                ref else_blk,
160                ..
161            } => {
162                let then_blk_id = region.create_block();
163                let else_blk_id = region.create_block();
164                import_block(current_module, region, then_blk_id, then_blk);
165                import_block(current_module, region, else_blk_id, else_blk);
166                region.block_mut(current_block_id).push(Op::If(then_blk_id, else_blk_id), *span);
167            }
168            ast::Op::Repeat {
169                span,
170                count,
171                ref body,
172                ..
173            } => {
174                let body_blk = region.create_block();
175                import_block(current_module, region, body_blk, body);
176                let count = u16::try_from(*count).unwrap_or_else(|_| {
177                    panic!("invalid repeat count: expected {count} to be less than 255")
178                });
179                region.block_mut(current_block_id).push(Op::Repeat(count, body_blk), *span);
180            }
181            ast::Op::While { span, ref body, .. } => {
182                let body_blk = region.create_block();
183                import_block(current_module, region, body_blk, body);
184                region.block_mut(current_block_id).push(Op::While(body_blk), *span);
185            }
186        }
187    }
188}
189
190/// Emit a [miden_assembly::ast::CodeBlock] by recursively visiting a tree of blocks
191/// starting with `block_id`, using the provided imports and local/external procedure maps.
192#[allow(clippy::only_used_in_recursion)]
193fn emit_block(
194    block_id: BlockId,
195    blocks: &PrimaryMap<BlockId, Block>,
196    imports: &ModuleImportInfo,
197    locals: &BTreeSet<FunctionIdent>,
198) -> ast::Block {
199    let current_block = &blocks[block_id];
200    let mut ops = Vec::with_capacity(current_block.ops.len());
201    for op in current_block.ops.iter().copied() {
202        let span = op.span();
203        match op.into_inner() {
204            Op::If(then_blk, else_blk) => {
205                let then_blk = emit_block(then_blk, blocks, imports, locals);
206                let else_blk = emit_block(else_blk, blocks, imports, locals);
207                ops.push(ast::Op::If {
208                    span,
209                    then_blk,
210                    else_blk,
211                });
212            }
213            Op::While(blk) => {
214                let body = emit_block(blk, blocks, imports, locals);
215                ops.push(ast::Op::While { span, body });
216            }
217            Op::Repeat(n, blk) => {
218                let body = emit_block(blk, blocks, imports, locals);
219                ops.push(ast::Op::Repeat {
220                    span,
221                    count: n as u32,
222                    body,
223                });
224            }
225            op => {
226                ops.extend(
227                    op.into_masm(imports, locals)
228                        .into_iter()
229                        .map(|inst| ast::Op::Inst(Span::new(span, inst))),
230                );
231            }
232        }
233    }
234
235    ast::Block::new(Default::default(), ops)
236}