1use alloc::collections::BTreeMap;
2
3use intrusive_collections::{
4 intrusive_adapter,
5 linked_list::{Cursor, CursorMut},
6 LinkedList, LinkedListLink, RBTreeLink,
7};
8use rustc_hash::FxHashSet;
9
10use self::formatter::PrettyPrint;
11use crate::{
12 diagnostics::{miette, Diagnostic, DiagnosticsHandler, Report, Severity, Spanned},
13 *,
14};
15
16#[derive(Debug, thiserror::Error, Diagnostic)]
18#[error("module {} has already been declared", .name)]
19#[diagnostic()]
20pub struct ModuleConflictError {
21 #[label("duplicate declaration occurs here")]
22 pub span: SourceSpan,
23 pub name: Symbol,
24}
25impl ModuleConflictError {
26 pub fn new(name: Ident) -> Self {
27 Self {
28 span: name.span,
29 name: name.as_symbol(),
30 }
31 }
32}
33
/// An intrusive red-black tree of [Module]s, using [ModuleTreeAdapter] (keyed by module name).
pub type ModuleTree = intrusive_collections::RBTree<ModuleTreeAdapter>;
/// An intrusive linked list of [Module]s, using [ModuleListAdapter].
pub type ModuleList = intrusive_collections::LinkedList<ModuleListAdapter>;

// Adapter for storing a `Box<Module>` in a `ModuleList`, linked through `Module::list_link`.
intrusive_adapter!(pub ModuleListAdapter = Box<Module>: Module { list_link: LinkedListLink });
// Adapter for storing a `Box<Module>` in a `ModuleTree`, linked through `Module::link`.
intrusive_adapter!(pub ModuleTreeAdapter = Box<Module>: Module { link: RBTreeLink });
39impl<'a> intrusive_collections::KeyAdapter<'a> for ModuleTreeAdapter {
40 type Key = Ident;
41
42 #[inline]
43 fn get_key(&self, module: &'a Module) -> Ident {
44 module.name
45 }
46}
47
/// A module: a named container for data segments, global variables, and functions.
///
/// A module carries two intrusive links, so that it can be stored in a [ModuleTree] and in a
/// [ModuleList].
#[derive(Spanned, AnalysisKey)]
pub struct Module {
    /// The link used by [ModuleTreeAdapter]; also what [Module::is_detached] inspects
    link: RBTreeLink,
    /// The link used by [ModuleListAdapter]
    list_link: LinkedListLink,
    /// The name of this module; it also supplies the module's span and its analysis key
    #[span]
    #[analysis_key]
    pub name: Ident,
    /// Optional documentation attached to this module
    pub docs: Option<String>,
    /// The number of pages of reserved memory; initialized to zero
    reserved_memory_pages: u32,
    /// The size of a page, in bytes; initialized to 64KiB
    page_size: u32,
    /// The data segments declared for this module
    pub(crate) segments: DataSegmentTable,
    /// The global variables (and constants) declared for this module
    pub(crate) globals: GlobalVariableTable,
    /// The functions attached to this module, in list order
    pub(crate) functions: LinkedList<FunctionListAdapter>,
    /// Whether this module was created as a kernel module
    is_kernel: bool,
}
101impl fmt::Display for Module {
102 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
103 self.pretty_print(f)
104 }
105}
106impl formatter::PrettyPrint for Module {
107 fn render(&self) -> formatter::Document {
108 use crate::formatter::*;
109
110 let mut header =
111 const_text("(") + const_text("module") + const_text(" ") + display(self.name);
112 if self.is_kernel {
113 header += const_text(" ") + const_text("(") + const_text("kernel") + const_text(")");
114 }
115
116 let segments = self
117 .segments
118 .iter()
119 .map(PrettyPrint::render)
120 .reduce(|acc, doc| acc + nl() + doc)
121 .map(|doc| const_text(";; Data Segments") + nl() + doc)
122 .unwrap_or(Document::Empty);
123
124 let constants = self
125 .globals
126 .constants()
127 .map(|(constant, constant_data)| {
128 const_text("(")
129 + const_text("const")
130 + const_text(" ")
131 + const_text("(")
132 + const_text("id")
133 + const_text(" ")
134 + display(constant.as_u32())
135 + const_text(")")
136 + const_text(" ")
137 + text(format!("{:#x}", constant_data.as_ref()))
138 + const_text(")")
139 })
140 .reduce(|acc, doc| acc + nl() + doc)
141 .map(|doc| const_text(";; Constants") + nl() + doc)
142 .unwrap_or(Document::Empty);
143
144 let globals = self
145 .globals
146 .iter()
147 .map(PrettyPrint::render)
148 .reduce(|acc, doc| acc + nl() + doc)
149 .map(|doc| const_text(";; Global Variables") + nl() + doc)
150 .unwrap_or(Document::Empty);
151
152 let mut external_functions = BTreeMap::<FunctionIdent, Signature>::default();
153 let functions = self
154 .functions
155 .iter()
156 .map(|fun| {
157 for import in fun.dfg.imports() {
158 if import.id.module == self.name {
160 continue;
161 }
162 external_functions.entry(import.id).or_insert_with(|| import.signature.clone());
163 }
164 fun.render()
165 })
166 .reduce(|acc, doc| acc + nl() + nl() + doc)
167 .map(|doc| const_text(";; Functions") + nl() + doc)
168 .unwrap_or(Document::Empty);
169
170 let imports = external_functions
171 .into_iter()
172 .map(|(id, signature)| ExternalFunction { id, signature }.render())
173 .reduce(|acc, doc| acc + nl() + doc)
174 .map(|doc| const_text(";; Imports") + nl() + doc)
175 .unwrap_or(Document::Empty);
176
177 let body = vec![segments, constants, globals, functions, imports]
178 .into_iter()
179 .filter(|section| !section.is_empty())
180 .fold(nl(), |a, b| {
181 if matches!(a, Document::Newline) {
182 indent(4, a + b)
183 } else {
184 a + nl() + indent(4, nl() + b)
185 }
186 });
187
188 if body.is_empty() {
189 header + const_text(")") + nl()
190 } else {
191 header + body + nl() + const_text(")") + nl()
192 }
193 }
194}
195impl fmt::Debug for Module {
196 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197 f.debug_struct("Module")
198 .field("name", &self.name)
199 .field("reserved_memory_pages", &self.reserved_memory_pages)
200 .field("page_size", &self.page_size)
201 .field("is_kernel", &self.is_kernel)
202 .field("docs", &self.docs)
203 .field("segments", &self.segments)
204 .field("globals", &self.globals)
205 .field("functions", &self.functions)
206 .finish()
207 }
208}
209impl midenc_session::Emit for Module {
210 fn name(&self) -> Option<crate::Symbol> {
211 Some(self.name.as_symbol())
212 }
213
214 fn output_type(&self, _mode: midenc_session::OutputMode) -> midenc_session::OutputType {
215 midenc_session::OutputType::Hir
216 }
217
218 fn write_to<W: std::io::Write>(
219 &self,
220 mut writer: W,
221 mode: midenc_session::OutputMode,
222 _session: &midenc_session::Session,
223 ) -> std::io::Result<()> {
224 assert_eq!(
225 mode,
226 midenc_session::OutputMode::Text,
227 "binary mode is not supported for HIR modules"
228 );
229 writer.write_fmt(format_args!("{}", self))
230 }
231}
232impl Eq for Module {}
233impl PartialEq for Module {
234 fn eq(&self, other: &Self) -> bool {
235 let is_eq = self.name == other.name
236 && self.is_kernel == other.is_kernel
237 && self.reserved_memory_pages == other.reserved_memory_pages
238 && self.page_size == other.page_size
239 && self.docs == other.docs
240 && self.segments.iter().eq(other.segments.iter())
241 && self.globals.len() == other.globals.len()
242 && self.functions.iter().count() == other.functions.iter().count();
243 if !is_eq {
244 return false;
245 }
246
247 for global in self.globals.iter() {
248 let id = global.id();
249 if !other.globals.contains_key(id) {
250 return false;
251 }
252 let other_global = other.globals.get(id);
253 if global != other_global {
254 return false;
255 }
256 }
257
258 for function in self.functions.iter() {
259 if !other.contains(function.id.function) {
260 return false;
261 }
262 if let Some(other_function) = other.function(function.id.function) {
263 if function != other_function {
264 return false;
265 }
266 } else {
267 return false;
268 }
269 }
270
271 true
272 }
273}
274
/// Asserts that `$function` may be attached to `$module`, panicking otherwise.
///
/// `$module` must expose `name` and `is_kernel` fields (both [Module] and [ModuleCursor] do),
/// and `$function` must expose `id.module`, `is_detached()`, `is_kernel()` and `is_public()`.
///
/// The following rules are enforced:
///
/// * the function must have been declared with the name of this module
/// * the function must not be attached to a module already
/// * kernel functions may only be added to kernel modules
/// * in a kernel module, every public function must be a kernel function
macro_rules! assert_valid_function {
    ($module:ident, $function:ident) => {
        assert_eq!($module.name, $function.id.module, "mismatched module identifiers");
        // It is the function, not the module, which must be detached
        assert!(
            $function.is_detached(),
            "cannot attach a function that is already attached to a module"
        );
        if $function.is_kernel() {
            assert!($module.is_kernel, "cannot add kernel functions to a non-kernel module");
        } else if $module.is_kernel && $function.is_public() {
            panic!(
                "functions with external linkage in kernel modules must use the kernel calling \
                 convention"
            );
        }
    };
}
294
295impl Module {
296 pub fn new<S: Into<Ident>>(name: S) -> Self {
298 Self::make(name.into(), false)
299 }
300
301 pub fn new_with_span<S: AsRef<str>>(name: S, span: SourceSpan) -> Self {
303 let name = Ident::new(Symbol::intern(name.as_ref()), span);
304 Self::make(name, false)
305 }
306
307 pub fn new_kernel<S: Into<Ident>>(name: S) -> Self {
309 Self::make(name.into(), true)
310 }
311
312 pub fn new_kernel_with_span<S: AsRef<str>>(name: S, span: SourceSpan) -> Self {
314 let name = Ident::new(Symbol::intern(name.as_ref()), span);
315 Self::make(name, true)
316 }
317
318 fn make(name: Ident, is_kernel: bool) -> Self {
319 Self {
320 link: Default::default(),
321 list_link: Default::default(),
322 name,
323 docs: None,
324 reserved_memory_pages: 0,
325 page_size: 64 * 1024,
326 segments: Default::default(),
327 globals: GlobalVariableTable::new(ConflictResolutionStrategy::None),
328 functions: Default::default(),
329 is_kernel,
330 }
331 }
332
333 #[inline]
335 pub const fn page_size(&self) -> u32 {
336 self.page_size
337 }
338
339 #[inline]
342 pub const fn reserved_memory_pages(&self) -> u32 {
343 self.reserved_memory_pages
344 }
345
346 #[inline]
349 pub const fn reserved_memory_bytes(&self) -> u32 {
350 self.reserved_memory_pages * self.page_size
351 }
352
353 pub fn set_reserved_memory_size(&mut self, size: u32) {
358 self.reserved_memory_pages = size;
359 }
360
361 #[inline]
363 pub const fn is_kernel(&self) -> bool {
364 self.is_kernel
365 }
366
367 pub fn is_detached(&self) -> bool {
369 !self.link.is_linked()
370 }
371
372 pub fn segments(&self) -> &DataSegmentTable {
374 &self.segments
375 }
376
377 pub fn declare_data_segment(
386 &mut self,
387 offset: Offset,
388 size: u32,
389 init: ConstantData,
390 readonly: bool,
391 ) -> Result<(), DataSegmentError> {
392 self.segments.declare(offset, size, init, readonly)
393 }
394
395 pub fn globals(&self) -> &GlobalVariableTable {
397 &self.globals
398 }
399
400 pub fn declare_global_variable(
409 &mut self,
410 name: Ident,
411 ty: Type,
412 linkage: Linkage,
413 init: Option<ConstantData>,
414 ) -> Result<GlobalVariable, GlobalVariableError> {
415 self.globals.declare(name, ty, linkage, init)
416 }
417
418 pub fn set_global_initializer(
423 &mut self,
424 gv: GlobalVariable,
425 init: ConstantData,
426 ) -> Result<(), GlobalVariableError> {
427 self.globals.set_initializer(gv, init)
428 }
429
430 #[inline]
432 pub fn global(&self, id: GlobalVariable) -> &GlobalVariableData {
433 self.globals.get(id)
434 }
435
436 pub fn find_global(&self, name: Ident) -> Option<&GlobalVariableData> {
438 self.globals.find(name).map(|gv| self.globals.get(gv))
439 }
440
441 pub fn entrypoint(&self) -> Option<FunctionIdent> {
443 self.functions.iter().find_map(|f| {
444 if f.has_attribute(&symbols::Entrypoint) {
445 Some(f.id)
446 } else {
447 None
448 }
449 })
450 }
451
452 pub fn functions<'a, 'b: 'a>(
456 &'b self,
457 ) -> intrusive_collections::linked_list::Iter<'a, FunctionListAdapter> {
458 self.functions.iter()
459 }
460
461 pub fn function<'a, 'b: 'a>(&'b self, id: Ident) -> Option<&'a Function> {
463 self.cursor_at(id).get()
464 }
465
466 pub fn imports(&self) -> ModuleImportInfo {
469 let mut imports = ModuleImportInfo::default();
470 let locals = self.functions.iter().map(|f| f.id).collect::<FxHashSet<FunctionIdent>>();
471
472 for function in self.functions.iter() {
473 for import in function.imports() {
474 if !locals.contains(&import.id) {
475 imports.add(import.id);
476 }
477 }
478 }
479 imports
480 }
481
482 pub fn contains(&self, name: Ident) -> bool {
484 self.function(name).is_some()
485 }
486
487 pub fn unlink(&mut self, id: Ident) -> Box<Function> {
489 let mut cursor = self.cursor_mut_at(id);
490 cursor
491 .remove()
492 .unwrap_or_else(|| panic!("cursor pointing to a null when removing function id: {id}"))
493 }
494
495 pub fn push(&mut self, function: Box<Function>) -> Result<(), SymbolConflictError> {
504 assert_valid_function!(self, function);
505 if let Some(prev) = self.function(function.id.function) {
506 return Err(SymbolConflictError(prev.id));
507 }
508 self.functions.push_back(function);
509 Ok(())
510 }
511
512 pub fn insert_before(
517 &mut self,
518 function: Box<Function>,
519 before: Ident,
520 ) -> Result<(), SymbolConflictError> {
521 assert_valid_function!(self, function);
522 if let Some(prev) = self.function(function.id.function) {
523 return Err(SymbolConflictError(prev.id));
524 }
525
526 let mut cursor = self.cursor_mut_at(before);
527 cursor.insert_before(function);
528
529 Ok(())
530 }
531
532 pub fn insert_after(
537 &mut self,
538 function: Box<Function>,
539 after: Ident,
540 ) -> Result<(), SymbolConflictError> {
541 assert_valid_function!(self, function);
542 if let Some(prev) = self.function(function.id.function) {
543 return Err(SymbolConflictError(prev.id));
544 }
545
546 let mut cursor = self.cursor_mut_at(after);
547 if cursor.is_null() {
548 cursor.insert_before(function);
549 } else {
550 cursor.insert_after(function);
551 }
552
553 Ok(())
554 }
555
556 pub fn pop_front(&mut self) -> Option<Box<Function>> {
558 self.functions.pop_front()
559 }
560
561 #[inline]
567 pub fn cursor_mut<'a, 'b: 'a>(&'b mut self) -> ModuleCursor<'a> {
568 ModuleCursor {
569 cursor: self.functions.front_mut(),
570 name: self.name,
571 is_kernel: self.is_kernel,
572 }
573 }
574
575 pub fn cursor_at<'a, 'b: 'a>(&'b self, id: Ident) -> Cursor<'a, FunctionListAdapter> {
579 let mut cursor = self.functions.front();
580 while let Some(function) = cursor.get() {
581 if function.id.function == id {
582 break;
583 }
584 cursor.move_next();
585 }
586 cursor
587 }
588
589 pub fn cursor_mut_at<'a, 'b: 'a>(&'b mut self, id: Ident) -> ModuleCursor<'a> {
593 let mut cursor = self.functions.front_mut();
594 while let Some(function) = cursor.get() {
595 if function.id.function == id {
596 break;
597 }
598 cursor.move_next();
599 }
600 ModuleCursor {
601 cursor,
602 name: self.name,
603 is_kernel: self.is_kernel,
604 }
605 }
606}
607
/// A mutable cursor over the functions of a [Module].
///
/// It wraps a `CursorMut` over the function list, along with a copy of the module's name and
/// kernel flag, which are used to validate functions before they are inserted.
pub struct ModuleCursor<'a> {
    /// The underlying cursor into the module's function list
    cursor: CursorMut<'a, FunctionListAdapter>,
    /// The name of the module this cursor was created from
    name: Ident,
    /// Whether the module this cursor was created from is a kernel module
    is_kernel: bool,
}
impl<'a> ModuleCursor<'a> {
    /// Returns true if the underlying cursor is null, i.e. it is not positioned at a function
    #[inline(always)]
    pub fn is_null(&self) -> bool {
        self.cursor.is_null()
    }

    /// Get a reference to the function under the cursor, or `None` if the underlying cursor
    /// does not yield one.
    #[inline(always)]
    pub fn get(&self) -> Option<&Function> {
        self.cursor.get()
    }

    /// Insert `function` after the current position of the underlying cursor.
    ///
    /// # Panics
    ///
    /// Panics if `function` is not valid for the module this cursor was created from, see
    /// `assert_valid_function!`.
    ///
    /// NOTE: Unlike the insertion methods of [Module], no check for an existing function with
    /// the same name is performed here.
    pub fn insert_after(&mut self, function: Box<Function>) {
        assert_valid_function!(self, function);
        self.cursor.insert_after(function);
    }

    /// Insert `function` before the current position of the underlying cursor.
    ///
    /// # Panics
    ///
    /// Panics if `function` is not valid for the module this cursor was created from, see
    /// `assert_valid_function!`.
    ///
    /// NOTE: Unlike the insertion methods of [Module], no check for an existing function with
    /// the same name is performed here.
    pub fn insert_before(&mut self, function: Box<Function>) {
        assert_valid_function!(self, function);
        self.cursor.insert_before(function);
    }

    /// Move the underlying cursor to the next position in the function list
    #[inline(always)]
    pub fn move_next(&mut self) {
        self.cursor.move_next();
    }

    /// Move the underlying cursor to the previous position in the function list
    #[inline(always)]
    pub fn move_prev(&mut self) {
        self.cursor.move_prev();
    }

    /// Get a read-only cursor at the position following the current one, without moving
    /// this cursor.
    #[inline(always)]
    pub fn peek_next(&self) -> Cursor<'_, FunctionListAdapter> {
        self.cursor.peek_next()
    }

    /// Get a read-only cursor at the position preceding the current one, without moving
    /// this cursor.
    #[inline(always)]
    pub fn peek_prev(&self) -> Cursor<'_, FunctionListAdapter> {
        self.cursor.peek_prev()
    }

    /// Remove the function under the cursor from the function list, returning ownership of
    /// it, or `None` if the underlying cursor had nothing to remove.
    #[inline(always)]
    pub fn remove(&mut self) -> Option<Box<Function>> {
        self.cursor.remove()
    }
}
697
698pub struct ModuleBuilder {
699 module: Box<Module>,
700}
701impl From<Box<Module>> for ModuleBuilder {
702 fn from(module: Box<Module>) -> Self {
703 Self { module }
704 }
705}
706impl ModuleBuilder {
707 pub fn new<S: Into<Ident>>(name: S) -> Self {
708 Self {
709 module: Box::new(Module::new(name)),
710 }
711 }
712
713 pub fn new_kernel<S: Into<Ident>>(name: S) -> Self {
714 Self {
715 module: Box::new(Module::new_kernel(name)),
716 }
717 }
718
719 pub fn with_span(&mut self, span: SourceSpan) -> &mut Self {
720 self.module.name = Ident::new(self.module.name.as_symbol(), span);
721 self
722 }
723
724 pub fn with_docs<S: Into<String>>(&mut self, docs: S) -> &mut Self {
725 self.module.docs = Some(docs.into());
726 self
727 }
728
729 pub fn with_page_size(&mut self, page_size: u32) -> &mut Self {
730 self.module.page_size = page_size;
731 self
732 }
733
734 pub fn with_reserved_memory_pages(&mut self, num_pages: u32) -> &mut Self {
735 self.module.reserved_memory_pages = num_pages;
736 self
737 }
738
739 pub fn name(&self) -> Ident {
740 self.module.name
741 }
742
743 pub fn declare_global_variable<S: AsRef<str>>(
744 &mut self,
745 name: S,
746 ty: Type,
747 linkage: Linkage,
748 init: Option<ConstantData>,
749 span: SourceSpan,
750 ) -> Result<GlobalVariable, GlobalVariableError> {
751 let name = Ident::new(Symbol::intern(name.as_ref()), span);
752 self.module.declare_global_variable(name, ty, linkage, init)
753 }
754
755 pub fn set_global_initializer(
756 &mut self,
757 gv: GlobalVariable,
758 init: ConstantData,
759 ) -> Result<(), GlobalVariableError> {
760 self.module.set_global_initializer(gv, init)
761 }
762
763 pub fn declare_data_segment<I: Into<ConstantData>>(
764 &mut self,
765 offset: Offset,
766 size: u32,
767 init: I,
768 readonly: bool,
769 ) -> Result<(), DataSegmentError> {
770 self.module.declare_data_segment(offset, size, init.into(), readonly)
771 }
772
773 pub fn function<'a, 'b: 'a, S: Into<Ident>>(
775 &'b mut self,
776 name: S,
777 signature: Signature,
778 ) -> Result<ModuleFunctionBuilder<'a>, SymbolConflictError> {
779 let name = name.into();
780 if let Some(prev) = self.module.function(name) {
781 return Err(SymbolConflictError(prev.id));
782 }
783
784 let id = FunctionIdent {
785 module: self.module.name,
786 function: name,
787 };
788 let function = Box::new(Function::new(id, signature));
789 let entry = function.dfg.entry;
790
791 Ok(ModuleFunctionBuilder {
792 builder: self,
793 function,
794 position: entry,
795 })
796 }
797
798 pub fn build(self) -> Box<Module> {
799 self.module
800 }
801}
802
803pub struct ModuleFunctionBuilder<'m> {
804 builder: &'m mut ModuleBuilder,
805 function: Box<Function>,
806 position: Block,
807}
808impl<'m> ModuleFunctionBuilder<'m> {
809 pub fn with_span(&mut self, span: SourceSpan) -> &mut Self {
810 self.function.id.function = Ident::new(self.function.id.function.as_symbol(), span);
811 self
812 }
813
814 pub fn id(&self) -> FunctionIdent {
816 self.function.id
817 }
818
819 pub fn signature(&self) -> &Signature {
821 &self.function.signature
822 }
823
824 pub fn module<'a, 'b: 'a>(&'b mut self) -> &'a mut ModuleBuilder {
825 self.builder
826 }
827
828 #[inline(always)]
829 pub fn data_flow_graph(&self) -> &DataFlowGraph {
830 &self.function.dfg
831 }
832
833 #[inline(always)]
834 pub fn data_flow_graph_mut(&mut self) -> &mut DataFlowGraph {
835 &mut self.function.dfg
836 }
837
838 #[inline]
839 pub fn entry_block(&self) -> Block {
840 self.function.dfg.entry
841 }
842
843 #[inline]
844 pub fn current_block(&self) -> Block {
845 self.position
846 }
847
848 #[inline]
849 pub fn switch_to_block(&mut self, block: Block) {
850 self.position = block;
851 }
852
853 pub fn create_block(&mut self) -> Block {
854 self.data_flow_graph_mut().create_block()
855 }
856
857 pub fn block_params(&self, block: Block) -> &[Value] {
858 self.data_flow_graph().block_params(block)
859 }
860
861 pub fn append_block_param(&mut self, block: Block, ty: Type, span: SourceSpan) -> Value {
862 self.data_flow_graph_mut().append_block_param(block, ty, span)
863 }
864
865 pub fn inst_results(&self, inst: Inst) -> &[Value] {
866 self.data_flow_graph().inst_results(inst)
867 }
868
869 pub fn first_result(&self, inst: Inst) -> Value {
870 self.data_flow_graph().first_result(inst)
871 }
872
873 pub fn set_attribute(&mut self, name: impl Into<Symbol>, value: impl Into<AttributeValue>) {
874 self.data_flow_graph_mut().set_attribute(name, value);
875 }
876
877 pub fn import_function<M, F>(
878 &mut self,
879 module: M,
880 function: F,
881 signature: Signature,
882 ) -> Result<FunctionIdent, SymbolConflictError>
883 where
884 M: Into<Ident>,
885 F: Into<Ident>,
886 {
887 self.function.dfg.import_function(module.into(), function.into(), signature)
888 }
889
890 pub fn ins<'a, 'b: 'a>(&'b mut self) -> DefaultInstBuilder<'a> {
891 DefaultInstBuilder::new(&mut self.function.dfg, self.position)
892 }
893
894 pub fn build(self, diagnostics: &DiagnosticsHandler) -> Result<FunctionIdent, Report> {
895 let sig = self.function.signature();
896 match sig.linkage {
897 Linkage::External | Linkage::Internal => (),
898 linkage => {
899 return Err(diagnostics
900 .diagnostic(Severity::Error)
901 .with_message("invalid function definition")
902 .with_primary_label(
903 self.function.span(),
904 format!("invalid linkage: '{linkage}'"),
905 )
906 .with_help("Only 'external' and 'internal' linkage are valid for functions")
907 .into_report());
908 }
909 }
910
911 let is_kernel_module = self.builder.module.is_kernel;
912 let is_public = sig.is_public();
913
914 match sig.cc {
915 CallConv::Kernel if is_kernel_module => {
916 if !is_public {
917 return Err(diagnostics
918 .diagnostic(Severity::Error)
919 .with_message("invalid function definition")
920 .with_primary_label(
921 self.function.span(),
922 format!("expected 'external' linkage, but got '{}'", &sig.linkage),
923 )
924 .with_help(
925 "Functions declared with the 'kernel' calling convention must have \
926 'external' linkage",
927 )
928 .into_report());
929 }
930 }
931 CallConv::Kernel => {
932 return Err(diagnostics
933 .diagnostic(Severity::Error)
934 .with_message("invalid function definition")
935 .with_primary_label(
936 self.function.span(),
937 "unsupported use of 'kernel' calling convention",
938 )
939 .with_help(
940 "The 'kernel' calling convention is only allowed in kernel modules, on \
941 functions with external linkage",
942 )
943 .into_report());
944 }
945 cc if is_kernel_module && is_public => {
946 return Err(diagnostics
947 .diagnostic(Severity::Error)
948 .with_message("invalid function definition")
949 .with_primary_label(
950 self.function.span(),
951 format!("unsupported use of '{cc}' calling convention"),
952 )
953 .with_help(
954 "Functions with external linkage, must use the 'kernel' calling \
955 convention when defined in a kernel module",
956 )
957 .into_report());
958 }
959 _ => (),
960 }
961
962 let id = self.function.id;
963 self.builder.module.functions.push_back(self.function);
964
965 Ok(id)
966 }
967}