midenc_codegen_masm/masm/
function.rs

1use std::{collections::BTreeSet, fmt, sync::Arc};
2
3use cranelift_entity::EntityRef;
4use intrusive_collections::{intrusive_adapter, LinkedList, LinkedListAtomicLink};
5use miden_assembly::{
6    ast::{self, ProcedureName},
7    LibraryNamespace, LibraryPath,
8};
9use midenc_hir::{
10    diagnostics::{SourceSpan, Span, Spanned},
11    formatter::PrettyPrint,
12    AttributeSet, FunctionIdent, Ident, Signature, Type,
13};
14use smallvec::SmallVec;
15
16use super::*;
17
18intrusive_adapter!(pub FunctionListAdapter = Box<Function>: Function { link: LinkedListAtomicLink });
19intrusive_adapter!(pub FrozenFunctionListAdapter = Arc<Function>: Function { link: LinkedListAtomicLink });
20
21/// This represents a function in Miden Assembly
22#[derive(Spanned, Clone)]
23pub struct Function {
24    link: LinkedListAtomicLink,
25    #[span]
26    pub span: SourceSpan,
27    /// The attributes associated with this function
28    pub attrs: AttributeSet,
29    /// The name of this function
30    pub name: FunctionIdent,
31    /// The type signature of this function
32    pub signature: Signature,
33    /// The [Region] which forms the body of this function
34    pub body: Region,
35    /// The set of procedures invoked from the body of this function
36    invoked: BTreeSet<ast::Invoke>,
37    /// Locals allocated for this function
38    locals: SmallVec<[Local; 1]>,
39    /// The next available local index
40    next_local_id: usize,
41}
42impl Function {
43    pub fn new(name: FunctionIdent, signature: Signature) -> Self {
44        Self {
45            link: Default::default(),
46            span: SourceSpan::UNKNOWN,
47            attrs: Default::default(),
48            name,
49            signature,
50            body: Default::default(),
51            invoked: Default::default(),
52            locals: Default::default(),
53            next_local_id: 0,
54        }
55    }
56
57    /// Returns true if this function is decorated with the `entrypoint` attribute.
58    pub fn is_entrypoint(&self) -> bool {
59        use midenc_hir::symbols;
60
61        self.attrs.has(&symbols::Entrypoint)
62    }
63
64    /// Return the number of arguments expected on the operand stack
65    #[inline]
66    pub fn arity(&self) -> usize {
67        self.signature.arity()
68    }
69
70    /// Return the number of results produced by this function
71    #[inline]
72    pub fn num_results(&self) -> usize {
73        self.signature.results.len()
74    }
75
76    /// Allocate a new local in this function, using the provided data
77    ///
78    /// The index of the local is returned as it's identifier
79    pub fn alloc_local(&mut self, ty: Type) -> LocalId {
80        let num_words = ty.size_in_words();
81        let next_id = self.next_local_id;
82        assert!(
83            (next_id + num_words) < (u8::MAX as usize),
84            "unable to allocate a local of type {}: unable to allocate enough local memory",
85            &ty
86        );
87        let id = LocalId::new(next_id);
88        self.next_local_id += num_words;
89        let local = Local { id, ty };
90        self.locals.push(local);
91        id
92    }
93
94    /// Allocate `n` locals for use by this function.
95    ///
96    /// Each local can be independently accessed, but they are all of type `Felt`
97    pub fn alloc_n_locals(&mut self, n: u16) {
98        assert!(
99            (self.next_local_id + n as usize) < u16::MAX as usize,
100            "unable to allocate {n} locals"
101        );
102
103        let num_locals = self.locals.len();
104        self.locals.resize_with(num_locals + n as usize, || {
105            let id = LocalId::new(self.next_local_id);
106            self.next_local_id += 1;
107            Local { id, ty: Type::Felt }
108        });
109    }
110
111    /// Get the local with the given identifier
112    pub fn local(&self, id: LocalId) -> &Local {
113        self.locals.iter().find(|l| l.id == id).expect("invalid local id")
114    }
115
116    /// Return the locals allocated in this function as a slice
117    #[inline]
118    pub fn locals(&self) -> &[Local] {
119        self.locals.as_slice()
120    }
121
122    /// Get a reference to the entry block for this function
123    pub fn body(&self) -> &Block {
124        self.body.block(self.body.body)
125    }
126
127    /// Get a mutable reference to the entry block for this function
128    pub fn body_mut(&mut self) -> &mut Block {
129        self.body.block_mut(self.body.body)
130    }
131
132    /// Allocate a new code block in this function
133    #[inline(always)]
134    pub fn create_block(&mut self) -> BlockId {
135        self.body.create_block()
136    }
137
138    /// Get a reference to a [Block] by [BlockId]
139    #[inline(always)]
140    pub fn block(&self, id: BlockId) -> &Block {
141        self.body.block(id)
142    }
143
144    /// Get a mutable reference to a [Block] by [BlockId]
145    #[inline(always)]
146    pub fn block_mut(&mut self, id: BlockId) -> &mut Block {
147        self.body.block_mut(id)
148    }
149
150    pub fn invoked(&self) -> impl Iterator<Item = &ast::Invoke> + '_ {
151        self.invoked.iter()
152    }
153
154    pub fn register_invoked(&mut self, kind: ast::InvokeKind, target: ast::InvocationTarget) {
155        self.invoked.insert(ast::Invoke { kind, target });
156    }
157
158    #[inline(never)]
159    pub fn register_absolute_invocation_target(
160        &mut self,
161        kind: ast::InvokeKind,
162        target: FunctionIdent,
163    ) {
164        let module_name_span = target.module.span;
165        let module_id = ast::Ident::new_unchecked(Span::new(
166            module_name_span,
167            Arc::from(target.module.as_str().to_string().into_boxed_str()),
168        ));
169        let name_span = target.function.span;
170        let id = ast::Ident::new_unchecked(Span::new(
171            name_span,
172            Arc::from(target.function.as_str().to_string().into_boxed_str()),
173        ));
174        let path = LibraryPath::new(target.module.as_str()).unwrap_or_else(|_| {
175            LibraryPath::new_from_components(LibraryNamespace::Anon, [module_id])
176        });
177        let name = ast::ProcedureName::new_unchecked(id);
178        self.register_invoked(kind, ast::InvocationTarget::AbsoluteProcedurePath { name, path });
179    }
180
181    /// Return an implementation of [std::fmt::Display] for this function
182    pub fn display<'a, 'b: 'a>(&'b self, imports: &'b ModuleImportInfo) -> DisplayMasmFunction<'a> {
183        DisplayMasmFunction {
184            function: self,
185            imports,
186        }
187    }
188
189    pub fn from_ast(module: Ident, proc: &ast::Procedure) -> Box<Self> {
190        use midenc_hir::{Linkage, Symbol};
191
192        let proc_span = proc.name().span();
193        let proc_name = Symbol::intern(AsRef::<str>::as_ref(proc.name()));
194        let id = FunctionIdent {
195            module,
196            function: Ident::new(proc_name, proc_span),
197        };
198
199        let mut signature = Signature::new(vec![], vec![]);
200        let visibility = proc.visibility();
201        if !visibility.is_exported() {
202            signature.linkage = Linkage::Internal;
203        } else if visibility.is_syscall() {
204            signature.cc = midenc_hir::CallConv::Kernel;
205        }
206
207        let mut function = Box::new(Self::new(id, signature));
208        if proc.is_entrypoint() {
209            function.attrs.set(midenc_hir::attributes::ENTRYPOINT);
210        }
211
212        function.alloc_n_locals(proc.num_locals());
213
214        function.invoked.extend(proc.invoked().cloned());
215        function.body = Region::from_block(module, proc.body());
216
217        function
218    }
219
220    pub fn to_ast(
221        &self,
222        imports: &midenc_hir::ModuleImportInfo,
223        locals: &BTreeSet<FunctionIdent>,
224        tracing_enabled: bool,
225    ) -> ast::Procedure {
226        let visibility = if self.signature.is_kernel() {
227            ast::Visibility::Syscall
228        } else if self.signature.is_public() {
229            ast::Visibility::Public
230        } else {
231            ast::Visibility::Private
232        };
233
234        let id = ast::Ident::new_unchecked(Span::new(
235            self.name.function.span,
236            Arc::from(self.name.function.as_str().to_string().into_boxed_str()),
237        ));
238        let name = ast::ProcedureName::new_unchecked(id);
239
240        let mut body = self.body.to_block(imports, locals);
241
242        // Emit trace events on entry/exit from the procedure body, if not already present
243        if tracing_enabled {
244            emit_trace_frame_events(self.span, &mut body);
245        }
246
247        let num_locals = u16::try_from(self.locals.len()).expect("too many locals");
248        let mut proc = ast::Procedure::new(self.span, visibility, name, num_locals, body);
249        proc.extend_invoked(self.invoked().cloned());
250        proc
251    }
252}
253
254fn emit_trace_frame_events(span: SourceSpan, body: &mut ast::Block) {
255    use midenc_hir::{TRACE_FRAME_END, TRACE_FRAME_START};
256
257    let ops = body.iter().as_slice();
258    let has_frame_start = match ops.get(1) {
259        Some(ast::Op::Inst(inst)) => match inst.inner() {
260            ast::Instruction::Trace(imm) => {
261                matches!(imm, ast::Immediate::Value(val) if val.into_inner() == TRACE_FRAME_START)
262            }
263            _ => false,
264        },
265        _ => false,
266    };
267
268    // If we have the frame start event, we do not need to emit any further events
269    if has_frame_start {
270        return;
271    }
272
273    // Because [ast::Block] does not have a mutator that lets us insert an op at the start, we need
274    // to push the events at the end, then use access to the mutable slice via `iter_mut` to move
275    // elements around.
276    body.push(ast::Op::Inst(Span::new(span, ast::Instruction::Nop)));
277    body.push(ast::Op::Inst(Span::new(span, ast::Instruction::Trace(TRACE_FRAME_END.into()))));
278    body.push(ast::Op::Inst(Span::new(span, ast::Instruction::Nop)));
279    body.push(ast::Op::Inst(Span::new(
280        span,
281        ast::Instruction::Trace(TRACE_FRAME_START.into()),
282    )));
283    let ops = body.iter_mut().into_slice();
284    ops.rotate_right(2);
285}
286
287impl fmt::Debug for Function {
288    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289        f.debug_struct("Function")
290            .field("name", &self.name)
291            .field("signature", &self.signature)
292            .field("attrs", &self.attrs)
293            .field("locals", &self.locals)
294            .field("body", &self.body)
295            .finish()
296    }
297}
298
299#[doc(hidden)]
300pub struct DisplayMasmFunction<'a> {
301    function: &'a Function,
302    imports: &'a ModuleImportInfo,
303}
304impl<'a> midenc_hir::formatter::PrettyPrint for DisplayMasmFunction<'a> {
305    fn render(&self) -> midenc_hir::formatter::Document {
306        use midenc_hir::formatter::*;
307
308        if self.function.name.module.as_str() == LibraryNamespace::EXEC_PATH
309            && self.function.name.function.as_str() == ProcedureName::MAIN_PROC_NAME
310        {
311            let body = self.function.body.display(Some(self.function.name), self.imports);
312            return indent(4, const_text("begin") + nl() + body.render())
313                + nl()
314                + const_text("end")
315                + nl();
316        }
317
318        let visibility = if self.function.signature.is_kernel() {
319            ast::Visibility::Syscall
320        } else if self.function.signature.is_public() {
321            ast::Visibility::Public
322        } else {
323            ast::Visibility::Private
324        };
325        let name = if ast::Ident::validate(self.function.name.function).is_ok() {
326            text(self.function.name.function.as_str())
327        } else {
328            text(format!("\"{}\"", self.function.name.function.as_str()))
329        };
330        let mut doc = display(visibility) + const_text(".") + name;
331        if !self.function.locals.is_empty() {
332            doc += const_text(".") + display(self.function.locals.len());
333        }
334
335        let body = self.function.body.display(Some(self.function.name), self.imports);
336        doc + indent(4, nl() + body.render()) + nl() + const_text("end") + nl() + nl()
337    }
338}
339impl<'a> fmt::Display for DisplayMasmFunction<'a> {
340    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
341        self.pretty_print(f)
342    }
343}
344
345pub type FunctionList = LinkedList<FunctionListAdapter>;
346pub type FunctionListIter<'a> = intrusive_collections::linked_list::Iter<'a, FunctionListAdapter>;
347
348pub type FrozenFunctionList = LinkedList<FrozenFunctionListAdapter>;
349pub type FrozenFunctionListIter<'a> =
350    intrusive_collections::linked_list::Iter<'a, FrozenFunctionListAdapter>;
351
352pub(super) enum Functions {
353    Open(FunctionList),
354    Frozen(FrozenFunctionList),
355}
356impl Clone for Functions {
357    fn clone(&self) -> Self {
358        match self {
359            Self::Open(list) => {
360                let mut new_list = FunctionList::new(Default::default());
361                for f in list.iter() {
362                    new_list.push_back(Box::new(f.clone()));
363                }
364                Self::Open(new_list)
365            }
366            Self::Frozen(list) => {
367                let mut new_list = FrozenFunctionList::new(Default::default());
368                for f in list.iter() {
369                    new_list.push_back(Arc::from(Box::new(f.clone())));
370                }
371                Self::Frozen(new_list)
372            }
373        }
374    }
375}
376impl Default for Functions {
377    fn default() -> Self {
378        Self::Open(Default::default())
379    }
380}
381impl Functions {
382    pub fn iter(&self) -> impl Iterator<Item = &Function> + '_ {
383        match self {
384            Self::Open(ref list) => FunctionsIter::Open(list.iter()),
385            Self::Frozen(ref list) => FunctionsIter::Frozen(list.iter()),
386        }
387    }
388
389    pub fn push_back(&mut self, function: Box<Function>) {
390        match self {
391            Self::Open(ref mut list) => {
392                list.push_back(function);
393            }
394            Self::Frozen(_) => panic!("cannot insert function into frozen module"),
395        }
396    }
397
398    pub fn freeze(&mut self) {
399        if let Self::Open(ref mut functions) = self {
400            let mut frozen = FrozenFunctionList::default();
401
402            while let Some(function) = functions.pop_front() {
403                frozen.push_back(Arc::from(function));
404            }
405
406            *self = Self::Frozen(frozen);
407        }
408    }
409}
410
411enum FunctionsIter<'a> {
412    Open(FunctionListIter<'a>),
413    Frozen(FrozenFunctionListIter<'a>),
414}
415impl<'a> Iterator for FunctionsIter<'a> {
416    type Item = &'a Function;
417
418    fn next(&mut self) -> Option<Self::Item> {
419        match self {
420            Self::Open(ref mut iter) => iter.next(),
421            Self::Frozen(ref mut iter) => iter.next(),
422        }
423    }
424}