midenc_codegen_masm/masm/
module.rs

1use std::{collections::BTreeSet, fmt, path::Path, sync::Arc};
2
3use intrusive_collections::{intrusive_adapter, RBTree, RBTreeAtomicLink};
4use miden_assembly::{
5    ast::{self, ModuleKind},
6    LibraryPath,
7};
8use midenc_hir::{
9    diagnostics::{Report, SourceFile, SourceSpan, Span, Spanned},
10    formatter::PrettyPrint,
11    FunctionIdent, Ident, Symbol,
12};
13
14use super::{function::Functions, FrozenFunctionList, Function, ModuleImportInfo};
15
/// This represents a single compiled Miden Assembly module in a form that is
/// designed to integrate well with the rest of our IR. You can think of this
/// as an intermediate representation corresponding to the Miden Assembly AST,
/// i.e. [miden_assembly::ast::Module].
///
/// Functions are stored in a [Module] in a linked list, so as to allow precise
/// ordering of functions in the module body. We typically access all of the
/// functions in a given module, so O(1) access to a specific function is not
/// of primary importance.
#[derive(Clone)]
pub struct Module {
    /// The intrusive link used when this module is stored in a [ModuleTree] or
    /// [FrozenModuleTree]
    link: RBTreeAtomicLink,
    /// The source span associated with this module; [SourceSpan::UNKNOWN] for modules
    /// created via [Module::new]
    pub span: SourceSpan,
    /// The kind of this module, e.g. kernel or library
    pub kind: ModuleKind,
    /// The identifier of this module; [Module::new] interns it from the path of `name`,
    /// and it is the key under which the module is stored in a module tree
    pub id: Ident,
    /// The name of this module as a library path, e.g. `std::math::u64`
    pub name: LibraryPath,
    /// The module-scoped documentation for this module
    pub docs: Option<String>,
    /// The modules to import, along with their local aliases
    pub imports: ModuleImportInfo,
    /// The functions defined in this module
    functions: Functions,
    /// The set of re-exported functions declared in this module
    reexports: Vec<ast::ProcedureAlias>,
}
43impl Module {
44    /// Create a new, empty [Module] with the given name and kind.
45    pub fn new(name: LibraryPath, kind: ModuleKind) -> Self {
46        let id = Ident::with_empty_span(Symbol::intern(name.path()));
47        Self {
48            link: Default::default(),
49            kind,
50            span: SourceSpan::UNKNOWN,
51            id,
52            name,
53            docs: None,
54            imports: Default::default(),
55            functions: Default::default(),
56            reexports: Default::default(),
57        }
58    }
59
60    /// Parse a [Module] from `source` using the given [ModuleKind] and [LibraryPath]
61    pub fn parse(
62        kind: ModuleKind,
63        path: LibraryPath,
64        source: Arc<SourceFile>,
65    ) -> Result<Self, Report> {
66        let span = source.source_span();
67        let mut parser = ast::Module::parser(kind);
68        let ast = parser.parse(path, source)?;
69        Ok(Self::from_ast(&ast, span))
70    }
71
72    /// Returns true if this module is a kernel module
73    pub fn is_kernel(&self) -> bool {
74        self.kind.is_kernel()
75    }
76
77    /// Returns true if this module is an executable module
78    pub fn is_executable(&self) -> bool {
79        self.kind.is_executable()
80    }
81
82    /// If this module contains a function marked with the `entrypoint` attribute,
83    /// return the fully-qualified name of that function
84    pub fn entrypoint(&self) -> Option<FunctionIdent> {
85        if !self.is_executable() {
86            return None;
87        }
88
89        self.functions.iter().find_map(|f| {
90            if f.is_entrypoint() {
91                Some(f.name)
92            } else {
93                None
94            }
95        })
96    }
97
98    /// Returns true if this module contains a [Function] `name`
99    pub fn contains(&self, name: Ident) -> bool {
100        self.functions.iter().any(|f| f.name.function == name)
101    }
102
103    pub fn from_ast(ast: &ast::Module, span: SourceSpan) -> Self {
104        let mut module = Self::new(ast.path().clone(), ast.kind());
105        module.span = span;
106        module.docs = ast.docs().map(|s| s.to_string());
107
108        let mut imports = ModuleImportInfo::default();
109        for import in ast.imports() {
110            let span = import.name.span();
111            let alias = Symbol::intern(import.name.as_str());
112            let name = if import.is_aliased() {
113                Symbol::intern(import.path.last())
114            } else {
115                alias
116            };
117            imports.insert(midenc_hir::MasmImport { span, name, alias });
118        }
119
120        for export in ast.procedures() {
121            match export {
122                ast::Export::Alias(ref alias) => {
123                    module.reexports.push(alias.clone());
124                }
125                ast::Export::Procedure(ref proc) => {
126                    let function = Function::from_ast(module.id, proc);
127                    module.functions.push_back(function);
128                }
129            }
130        }
131
132        module
133    }
134
135    /// Freezes this program, preventing further modifications
136    pub fn freeze(mut self: Box<Self>) -> Arc<Module> {
137        self.functions.freeze();
138        Arc::from(self)
139    }
140
141    /// Get an iterator over the functions in this module
142    pub fn functions(&self) -> impl Iterator<Item = &Function> + '_ {
143        self.functions.iter()
144    }
145
146    /// Access the frozen functions list of this module, and panic if not frozen
147    pub fn unwrap_frozen_functions(&self) -> &FrozenFunctionList {
148        match self.functions {
149            Functions::Frozen(ref functions) => functions,
150            Functions::Open(_) => panic!("expected module to be frozen"),
151        }
152    }
153
154    /// Append a function to the end of this module
155    ///
156    /// NOTE: This function will panic if the module has been frozen
157    pub fn push_back(&mut self, function: Box<Function>) {
158        self.functions.push_back(function);
159    }
160
161    /// Convert this module into its [miden_assembly::ast::Module] representation.
162    pub fn to_ast(&self, tracing_enabled: bool) -> Result<ast::Module, Report> {
163        let mut ast = ast::Module::new(self.kind, self.name.clone()).with_span(self.span);
164        ast.set_docs(self.docs.clone().map(Span::unknown));
165
166        // Create module import table
167        for ir_import in self.imports.iter() {
168            let span = ir_import.span;
169            let name =
170                ast::Ident::new_with_span(span, ir_import.alias.as_str()).map_err(Report::msg)?;
171            let path = LibraryPath::new(ir_import.name.as_str()).expect("invalid import path");
172            let import = ast::Import {
173                span,
174                name,
175                path,
176                uses: 1,
177            };
178            let _ = ast.define_import(import);
179        }
180
181        // Translate functions
182        let locals = BTreeSet::from_iter(self.functions.iter().map(|f| f.name));
183
184        for reexport in self.reexports.iter() {
185            ast.define_procedure(ast::Export::Alias(reexport.clone()))?;
186        }
187
188        for function in self.functions.iter() {
189            ast.define_procedure(ast::Export::Procedure(function.to_ast(
190                &self.imports,
191                &locals,
192                tracing_enabled,
193            )))?;
194        }
195
196        Ok(ast)
197    }
198
199    /// Write this module to a new file under `dir`, assuming `dir` is the root directory for a
200    /// program.
201    ///
202    /// For example, if this module is named `std::math::u64`, then it will be written to
203    /// `<dir>/std/math/u64.masm`
204    pub fn write_to_directory<P: AsRef<Path>>(
205        &self,
206        dir: P,
207        session: &midenc_session::Session,
208    ) -> std::io::Result<()> {
209        use midenc_session::{Emit, OutputMode};
210
211        let mut path = dir.as_ref().to_path_buf();
212        assert!(path.is_dir());
213        for component in self.name.components() {
214            path.push(component.as_ref());
215        }
216        assert!(path.set_extension("masm"));
217
218        let ast = self.to_ast(false).map_err(std::io::Error::other)?;
219        ast.write_to_file(&path, OutputMode::Text, session)
220    }
221}
222impl midenc_hir::formatter::PrettyPrint for Module {
223    fn render(&self) -> midenc_hir::formatter::Document {
224        use midenc_hir::formatter::*;
225
226        let mut doc = Document::Empty;
227        if let Some(docs) = self.docs.as_ref() {
228            let fragment =
229                docs.lines().map(text).reduce(|acc, line| acc + nl() + text("#! ") + line);
230
231            if let Some(fragment) = fragment {
232                doc += fragment;
233            }
234        }
235
236        for (i, import) in self.imports.iter().enumerate() {
237            if i > 0 {
238                doc += nl();
239            }
240            if import.is_aliased() {
241                doc += flatten(
242                    const_text("use")
243                        + const_text(".")
244                        + text(format!("{}", import.name))
245                        + const_text("->")
246                        + text(format!("{}", import.alias)),
247                );
248            } else {
249                doc +=
250                    flatten(const_text("use") + const_text(".") + text(format!("{}", import.name)));
251            }
252        }
253
254        if !self.imports.is_empty() {
255            doc += nl() + nl();
256        }
257
258        for (i, export) in self.reexports.iter().enumerate() {
259            if i > 0 {
260                doc += nl();
261            }
262            doc += export.render();
263        }
264
265        if !self.reexports.is_empty() {
266            doc += nl() + nl();
267        }
268
269        for (i, func) in self.functions.iter().enumerate() {
270            if i > 0 {
271                doc += nl();
272            }
273            let func = func.display(&self.imports);
274            doc += func.render();
275        }
276
277        doc
278    }
279}
280impl fmt::Display for Module {
281    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
282        self.pretty_print(f)
283    }
284}
285impl midenc_session::Emit for Module {
286    fn name(&self) -> Option<Symbol> {
287        Some(self.id.as_symbol())
288    }
289
290    fn output_type(&self, _mode: midenc_session::OutputMode) -> midenc_session::OutputType {
291        midenc_session::OutputType::Masm
292    }
293
294    fn write_to<W: std::io::Write>(
295        &self,
296        writer: W,
297        mode: midenc_session::OutputMode,
298        session: &midenc_session::Session,
299    ) -> std::io::Result<()> {
300        let ast = self.to_ast(false).map_err(std::io::Error::other)?;
301        ast.write_to(writer, mode, session)
302    }
303}
304
// Adapter for trees which own their modules uniquely via `Box`, linked through `Module::link`
intrusive_adapter!(pub ModuleTreeAdapter = Box<Module>: Module { link: RBTreeAtomicLink });
// Adapter for trees which share their modules via `Arc`, linked through `Module::link`
intrusive_adapter!(pub FrozenModuleTreeAdapter = Arc<Module>: Module { link: RBTreeAtomicLink });
307impl<'a> intrusive_collections::KeyAdapter<'a> for ModuleTreeAdapter {
308    type Key = Ident;
309
310    #[inline]
311    fn get_key(&self, module: &'a Module) -> Ident {
312        module.id
313    }
314}
315impl<'a> intrusive_collections::KeyAdapter<'a> for FrozenModuleTreeAdapter {
316    type Key = Ident;
317
318    #[inline]
319    fn get_key(&self, module: &'a Module) -> Ident {
320        module.id
321    }
322}
323
/// A tree of uniquely owned (boxed) modules, keyed by module identifier
pub type ModuleTree = RBTree<ModuleTreeAdapter>;
/// An iterator over the modules of a [ModuleTree]
pub type ModuleTreeIter<'a> = intrusive_collections::rbtree::Iter<'a, ModuleTreeAdapter>;

/// A tree of shared (`Arc`) modules, keyed by module identifier
pub type FrozenModuleTree = RBTree<FrozenModuleTreeAdapter>;
/// An iterator over the modules of a [FrozenModuleTree]
pub type FrozenModuleTreeIter<'a> =
    intrusive_collections::rbtree::Iter<'a, FrozenModuleTreeAdapter>;
330
/// The set of modules of a program, which is either still open for insertion, or frozen.
pub(super) enum Modules {
    /// Modules are uniquely owned, and new modules may be inserted
    Open(ModuleTree),
    /// Modules are shared via `Arc`; inserting into a frozen set panics
    Frozen(FrozenModuleTree),
}
335impl Default for Modules {
336    fn default() -> Self {
337        Self::Open(Default::default())
338    }
339}
340impl Clone for Modules {
341    fn clone(&self) -> Self {
342        let mut out = ModuleTree::default();
343        for module in self.iter() {
344            out.insert(Box::new(module.clone()));
345        }
346        Self::Open(out)
347    }
348}
349impl Modules {
350    pub fn len(&self) -> usize {
351        match self {
352            Self::Open(ref tree) => tree.iter().count(),
353            Self::Frozen(ref tree) => tree.iter().count(),
354        }
355    }
356
357    pub fn iter(&self) -> impl Iterator<Item = &Module> + '_ {
358        match self {
359            Self::Open(ref tree) => ModulesIter::Open(tree.iter()),
360            Self::Frozen(ref tree) => ModulesIter::Frozen(tree.iter()),
361        }
362    }
363
364    pub fn get<Q>(&self, name: &Q) -> Option<&Module>
365    where
366        Q: ?Sized + Ord,
367        Ident: core::borrow::Borrow<Q>,
368    {
369        match self {
370            Self::Open(ref tree) => tree.find(name).get(),
371            Self::Frozen(ref tree) => tree.find(name).get(),
372        }
373    }
374
375    pub fn insert(&mut self, module: Box<Module>) {
376        match self {
377            Self::Open(ref mut tree) => {
378                tree.insert(module);
379            }
380            Self::Frozen(_) => panic!("cannot insert module into frozen program"),
381        }
382    }
383
384    pub fn freeze(&mut self) {
385        if let Self::Open(ref mut modules) = self {
386            let mut frozen = FrozenModuleTree::default();
387
388            let mut open = modules.front_mut();
389            while let Some(module) = open.remove() {
390                frozen.insert(module.freeze());
391            }
392
393            *self = Self::Frozen(frozen);
394        }
395    }
396}
397
/// An iterator over a [Modules] set, wrapping the iterator of whichever tree backs it
enum ModulesIter<'a> {
    /// Iterates over the modules of an open set
    Open(ModuleTreeIter<'a>),
    /// Iterates over the modules of a frozen set
    Frozen(FrozenModuleTreeIter<'a>),
}
402impl<'a> Iterator for ModulesIter<'a> {
403    type Item = &'a Module;
404
405    fn next(&mut self) -> Option<Self::Item> {
406        match self {
407            Self::Open(ref mut iter) => iter.next(),
408            Self::Frozen(ref mut iter) => iter.next(),
409        }
410    }
411}