midenc_codegen_masm/masm/
function.rs1use std::{collections::BTreeSet, fmt, sync::Arc};
2
3use cranelift_entity::EntityRef;
4use intrusive_collections::{intrusive_adapter, LinkedList, LinkedListAtomicLink};
5use miden_assembly::{
6 ast::{self, ProcedureName},
7 LibraryNamespace, LibraryPath,
8};
9use midenc_hir::{
10 diagnostics::{SourceSpan, Span, Spanned},
11 formatter::PrettyPrint,
12 AttributeSet, FunctionIdent, Ident, Signature, Type,
13};
14use smallvec::SmallVec;
15
16use super::*;
17
// Intrusive list adapter for `Box<Function>` nodes, linked through `Function::link`.
intrusive_adapter!(pub FunctionListAdapter = Box<Function>: Function { link: LinkedListAtomicLink });
// Intrusive list adapter for `Arc<Function>` nodes, linked through the same `Function::link` field.
intrusive_adapter!(pub FrozenFunctionListAdapter = Arc<Function>: Function { link: LinkedListAtomicLink });
20
/// A function that can be built from, and converted back into, a `miden_assembly`
/// [`ast::Procedure`] (see `Function::from_ast` and `Function::to_ast`).
#[derive(Spanned, Clone)]
pub struct Function {
    /// Link used to store this function in the intrusive lists declared by
    /// `FunctionListAdapter` / `FrozenFunctionListAdapter`.
    link: LinkedListAtomicLink,
    /// Source span of this function; this is the field the `Spanned` derive reads.
    #[span]
    pub span: SourceSpan,
    /// Attributes attached to this function (e.g. the entrypoint marker checked by
    /// `is_entrypoint`).
    pub attrs: AttributeSet,
    /// Fully-qualified (module + function) name.
    pub name: FunctionIdent,
    /// Type signature, including linkage and calling convention.
    pub signature: Signature,
    /// The region holding the blocks that make up the function body.
    pub body: Region,
    /// The set of invocations registered for this function.
    invoked: BTreeSet<ast::Invoke>,
    /// Locals allocated so far, in allocation order.
    locals: SmallVec<[Local; 1]>,
    /// The id that will be assigned to the next allocated local.
    next_local_id: usize,
}
42impl Function {
43 pub fn new(name: FunctionIdent, signature: Signature) -> Self {
44 Self {
45 link: Default::default(),
46 span: SourceSpan::UNKNOWN,
47 attrs: Default::default(),
48 name,
49 signature,
50 body: Default::default(),
51 invoked: Default::default(),
52 locals: Default::default(),
53 next_local_id: 0,
54 }
55 }
56
57 pub fn is_entrypoint(&self) -> bool {
59 use midenc_hir::symbols;
60
61 self.attrs.has(&symbols::Entrypoint)
62 }
63
64 #[inline]
66 pub fn arity(&self) -> usize {
67 self.signature.arity()
68 }
69
70 #[inline]
72 pub fn num_results(&self) -> usize {
73 self.signature.results.len()
74 }
75
76 pub fn alloc_local(&mut self, ty: Type) -> LocalId {
80 let num_words = ty.size_in_words();
81 let next_id = self.next_local_id;
82 assert!(
83 (next_id + num_words) < (u8::MAX as usize),
84 "unable to allocate a local of type {}: unable to allocate enough local memory",
85 &ty
86 );
87 let id = LocalId::new(next_id);
88 self.next_local_id += num_words;
89 let local = Local { id, ty };
90 self.locals.push(local);
91 id
92 }
93
94 pub fn alloc_n_locals(&mut self, n: u16) {
98 assert!(
99 (self.next_local_id + n as usize) < u16::MAX as usize,
100 "unable to allocate {n} locals"
101 );
102
103 let num_locals = self.locals.len();
104 self.locals.resize_with(num_locals + n as usize, || {
105 let id = LocalId::new(self.next_local_id);
106 self.next_local_id += 1;
107 Local { id, ty: Type::Felt }
108 });
109 }
110
111 pub fn local(&self, id: LocalId) -> &Local {
113 self.locals.iter().find(|l| l.id == id).expect("invalid local id")
114 }
115
116 #[inline]
118 pub fn locals(&self) -> &[Local] {
119 self.locals.as_slice()
120 }
121
122 pub fn body(&self) -> &Block {
124 self.body.block(self.body.body)
125 }
126
127 pub fn body_mut(&mut self) -> &mut Block {
129 self.body.block_mut(self.body.body)
130 }
131
132 #[inline(always)]
134 pub fn create_block(&mut self) -> BlockId {
135 self.body.create_block()
136 }
137
138 #[inline(always)]
140 pub fn block(&self, id: BlockId) -> &Block {
141 self.body.block(id)
142 }
143
144 #[inline(always)]
146 pub fn block_mut(&mut self, id: BlockId) -> &mut Block {
147 self.body.block_mut(id)
148 }
149
150 pub fn invoked(&self) -> impl Iterator<Item = &ast::Invoke> + '_ {
151 self.invoked.iter()
152 }
153
154 pub fn register_invoked(&mut self, kind: ast::InvokeKind, target: ast::InvocationTarget) {
155 self.invoked.insert(ast::Invoke { kind, target });
156 }
157
158 #[inline(never)]
159 pub fn register_absolute_invocation_target(
160 &mut self,
161 kind: ast::InvokeKind,
162 target: FunctionIdent,
163 ) {
164 let module_name_span = target.module.span;
165 let module_id = ast::Ident::new_unchecked(Span::new(
166 module_name_span,
167 Arc::from(target.module.as_str().to_string().into_boxed_str()),
168 ));
169 let name_span = target.function.span;
170 let id = ast::Ident::new_unchecked(Span::new(
171 name_span,
172 Arc::from(target.function.as_str().to_string().into_boxed_str()),
173 ));
174 let path = LibraryPath::new(target.module.as_str()).unwrap_or_else(|_| {
175 LibraryPath::new_from_components(LibraryNamespace::Anon, [module_id])
176 });
177 let name = ast::ProcedureName::new_unchecked(id);
178 self.register_invoked(kind, ast::InvocationTarget::AbsoluteProcedurePath { name, path });
179 }
180
181 pub fn display<'a, 'b: 'a>(&'b self, imports: &'b ModuleImportInfo) -> DisplayMasmFunction<'a> {
183 DisplayMasmFunction {
184 function: self,
185 imports,
186 }
187 }
188
189 pub fn from_ast(module: Ident, proc: &ast::Procedure) -> Box<Self> {
190 use midenc_hir::{Linkage, Symbol};
191
192 let proc_span = proc.name().span();
193 let proc_name = Symbol::intern(AsRef::<str>::as_ref(proc.name()));
194 let id = FunctionIdent {
195 module,
196 function: Ident::new(proc_name, proc_span),
197 };
198
199 let mut signature = Signature::new(vec![], vec![]);
200 let visibility = proc.visibility();
201 if !visibility.is_exported() {
202 signature.linkage = Linkage::Internal;
203 } else if visibility.is_syscall() {
204 signature.cc = midenc_hir::CallConv::Kernel;
205 }
206
207 let mut function = Box::new(Self::new(id, signature));
208 if proc.is_entrypoint() {
209 function.attrs.set(midenc_hir::attributes::ENTRYPOINT);
210 }
211
212 function.alloc_n_locals(proc.num_locals());
213
214 function.invoked.extend(proc.invoked().cloned());
215 function.body = Region::from_block(module, proc.body());
216
217 function
218 }
219
220 pub fn to_ast(
221 &self,
222 imports: &midenc_hir::ModuleImportInfo,
223 locals: &BTreeSet<FunctionIdent>,
224 tracing_enabled: bool,
225 ) -> ast::Procedure {
226 let visibility = if self.signature.is_kernel() {
227 ast::Visibility::Syscall
228 } else if self.signature.is_public() {
229 ast::Visibility::Public
230 } else {
231 ast::Visibility::Private
232 };
233
234 let id = ast::Ident::new_unchecked(Span::new(
235 self.name.function.span,
236 Arc::from(self.name.function.as_str().to_string().into_boxed_str()),
237 ));
238 let name = ast::ProcedureName::new_unchecked(id);
239
240 let mut body = self.body.to_block(imports, locals);
241
242 if tracing_enabled {
244 emit_trace_frame_events(self.span, &mut body);
245 }
246
247 let num_locals = u16::try_from(self.locals.len()).expect("too many locals");
248 let mut proc = ast::Procedure::new(self.span, visibility, name, num_locals, body);
249 proc.extend_invoked(self.invoked().cloned());
250 proc
251 }
252}
253
254fn emit_trace_frame_events(span: SourceSpan, body: &mut ast::Block) {
255 use midenc_hir::{TRACE_FRAME_END, TRACE_FRAME_START};
256
257 let ops = body.iter().as_slice();
258 let has_frame_start = match ops.get(1) {
259 Some(ast::Op::Inst(inst)) => match inst.inner() {
260 ast::Instruction::Trace(imm) => {
261 matches!(imm, ast::Immediate::Value(val) if val.into_inner() == TRACE_FRAME_START)
262 }
263 _ => false,
264 },
265 _ => false,
266 };
267
268 if has_frame_start {
270 return;
271 }
272
273 body.push(ast::Op::Inst(Span::new(span, ast::Instruction::Nop)));
277 body.push(ast::Op::Inst(Span::new(span, ast::Instruction::Trace(TRACE_FRAME_END.into()))));
278 body.push(ast::Op::Inst(Span::new(span, ast::Instruction::Nop)));
279 body.push(ast::Op::Inst(Span::new(
280 span,
281 ast::Instruction::Trace(TRACE_FRAME_START.into()),
282 )));
283 let ops = body.iter_mut().into_slice();
284 ops.rotate_right(2);
285}
286
287impl fmt::Debug for Function {
288 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289 f.debug_struct("Function")
290 .field("name", &self.name)
291 .field("signature", &self.signature)
292 .field("attrs", &self.attrs)
293 .field("locals", &self.locals)
294 .field("body", &self.body)
295 .finish()
296 }
297}
298
/// Pretty-printing wrapper around a [`Function`], created by `Function::display`.
#[doc(hidden)]
pub struct DisplayMasmFunction<'a> {
    /// The function being rendered.
    function: &'a Function,
    /// Import information passed along when rendering the function body.
    imports: &'a ModuleImportInfo,
}
304impl<'a> midenc_hir::formatter::PrettyPrint for DisplayMasmFunction<'a> {
305 fn render(&self) -> midenc_hir::formatter::Document {
306 use midenc_hir::formatter::*;
307
308 if self.function.name.module.as_str() == LibraryNamespace::EXEC_PATH
309 && self.function.name.function.as_str() == ProcedureName::MAIN_PROC_NAME
310 {
311 let body = self.function.body.display(Some(self.function.name), self.imports);
312 return indent(4, const_text("begin") + nl() + body.render())
313 + nl()
314 + const_text("end")
315 + nl();
316 }
317
318 let visibility = if self.function.signature.is_kernel() {
319 ast::Visibility::Syscall
320 } else if self.function.signature.is_public() {
321 ast::Visibility::Public
322 } else {
323 ast::Visibility::Private
324 };
325 let name = if ast::Ident::validate(self.function.name.function).is_ok() {
326 text(self.function.name.function.as_str())
327 } else {
328 text(format!("\"{}\"", self.function.name.function.as_str()))
329 };
330 let mut doc = display(visibility) + const_text(".") + name;
331 if !self.function.locals.is_empty() {
332 doc += const_text(".") + display(self.function.locals.len());
333 }
334
335 let body = self.function.body.display(Some(self.function.name), self.imports);
336 doc + indent(4, nl() + body.render()) + nl() + const_text("end") + nl() + nl()
337 }
338}
339impl<'a> fmt::Display for DisplayMasmFunction<'a> {
340 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
341 self.pretty_print(f)
342 }
343}
344
/// Intrusive linked list of `Box<Function>` nodes.
pub type FunctionList = LinkedList<FunctionListAdapter>;
/// Iterator over a [`FunctionList`].
pub type FunctionListIter<'a> = intrusive_collections::linked_list::Iter<'a, FunctionListAdapter>;

/// Intrusive linked list of `Arc<Function>` nodes.
pub type FrozenFunctionList = LinkedList<FrozenFunctionListAdapter>;
/// Iterator over a [`FrozenFunctionList`].
pub type FrozenFunctionListIter<'a> =
    intrusive_collections::linked_list::Iter<'a, FrozenFunctionListAdapter>;
351
/// The functions of a module, either still open for insertion or frozen.
pub(super) enum Functions {
    /// Functions are held in `Box`es; `push_back` is allowed in this state.
    Open(FunctionList),
    /// Functions are held in `Arc`s; `push_back` panics in this state.
    Frozen(FrozenFunctionList),
}
356impl Clone for Functions {
357 fn clone(&self) -> Self {
358 match self {
359 Self::Open(list) => {
360 let mut new_list = FunctionList::new(Default::default());
361 for f in list.iter() {
362 new_list.push_back(Box::new(f.clone()));
363 }
364 Self::Open(new_list)
365 }
366 Self::Frozen(list) => {
367 let mut new_list = FrozenFunctionList::new(Default::default());
368 for f in list.iter() {
369 new_list.push_back(Arc::from(Box::new(f.clone())));
370 }
371 Self::Frozen(new_list)
372 }
373 }
374 }
375}
376impl Default for Functions {
377 fn default() -> Self {
378 Self::Open(Default::default())
379 }
380}
381impl Functions {
382 pub fn iter(&self) -> impl Iterator<Item = &Function> + '_ {
383 match self {
384 Self::Open(ref list) => FunctionsIter::Open(list.iter()),
385 Self::Frozen(ref list) => FunctionsIter::Frozen(list.iter()),
386 }
387 }
388
389 pub fn push_back(&mut self, function: Box<Function>) {
390 match self {
391 Self::Open(ref mut list) => {
392 list.push_back(function);
393 }
394 Self::Frozen(_) => panic!("cannot insert function into frozen module"),
395 }
396 }
397
398 pub fn freeze(&mut self) {
399 if let Self::Open(ref mut functions) = self {
400 let mut frozen = FrozenFunctionList::default();
401
402 while let Some(function) = functions.pop_front() {
403 frozen.push_back(Arc::from(function));
404 }
405
406 *self = Self::Frozen(frozen);
407 }
408 }
409}
410
/// Iterator returned by `Functions::iter`, wrapping the iterator of whichever list variant
/// is active.
enum FunctionsIter<'a> {
    /// Iterating over an open list.
    Open(FunctionListIter<'a>),
    /// Iterating over a frozen list.
    Frozen(FrozenFunctionListIter<'a>),
}
415impl<'a> Iterator for FunctionsIter<'a> {
416 type Item = &'a Function;
417
418 fn next(&mut self) -> Option<Self::Item> {
419 match self {
420 Self::Open(ref mut iter) => iter.next(),
421 Self::Frozen(ref mut iter) => iter.next(),
422 }
423 }
424}