dusk_dire/
mir.rs

1use std::collections::HashMap;
2use std::ffi::CString;
3
4use index_vec::{IndexVec, define_index_type};
5use smallvec::SmallVec;
6use string_interner::DefaultSymbol as Sym;
7use display_adapter::display_adapter;
8use num_bigint::BigInt;
9
10use crate::hir::{Intrinsic, DeclId, StructId, EnumId, ModScopeId, ExternModId, ExternFunctionRef, GenericParamId};
11use crate::ty::{Type, InternalType, FunctionType, StructType};
12use crate::{Code, BlockId, OpId, InternalField};
13use crate::source_info::SourceRange;
14
15define_index_type!(pub struct FuncId = u32;);
16define_index_type!(pub struct StaticId = u32;);
17define_index_type!(pub struct StrId = u32;);
18define_index_type!(pub struct InstrId = u32;);
19
20pub const VOID_INSTR: OpId = OpId::from_usize_unchecked(0);
21
22#[derive(Clone, Debug, PartialEq)]
23pub struct SwitchCase {
24    pub value: Const,
25    pub bb: BlockId,
26}
27
28#[derive(Clone, Debug, PartialEq)]
29pub enum Instr {
30    Void,
31    Invalid, // Used temporarily when copying functions
32    Const(Const),
33    Alloca(Type),
34    LogicalNot(OpId),
35    Call { arguments: SmallVec<[OpId; 2]>, generic_arguments: Vec<Type>, func: FuncId },
36    ExternCall { arguments: SmallVec<[OpId; 2]>, func: ExternFunctionRef },
37    FunctionRef { generic_arguments: Vec<Type>, func: FuncId, },
38    Intrinsic { arguments: SmallVec<[OpId; 2]>, ty: Type, intr: Intrinsic },
39    Import(OpId), // TODO: make this an intrinsic, or even a core library function
40    Reinterpret(OpId, Type),
41    Truncate(OpId, Type),
42    SignExtend(OpId, Type),
43    ZeroExtend(OpId, Type),
44    FloatCast(OpId, Type),
45    FloatToInt(OpId, Type),
46    IntToFloat(OpId, Type),
47    Load(OpId),
48    Store { location: OpId, value: OpId },
49    AddressOfStatic(StaticId),
50    Pointer { op: OpId, is_mut: bool },
51    Struct { fields: SmallVec<[OpId; 2]>, id: StructId },
52    Enum { variants: SmallVec<[OpId; 2]>, id: EnumId },
53    FunctionTy { param_tys: Vec<OpId>, ret_ty: OpId },
54    StructLit { fields: SmallVec<[OpId; 2]>, id: StructId },
55    DirectFieldAccess { val: OpId, index: usize },
56    IndirectFieldAccess { val: OpId, index: usize },
57    InternalFieldAccess { val: OpId, field: InternalField },
58    Variant { enuum: EnumId, index: usize, payload: OpId },
59    DiscriminantAccess { val: OpId },
60    Ret(OpId),
61    Br(BlockId),
62    CondBr { condition: OpId, true_bb: BlockId, false_bb: BlockId },
63    SwitchBr { scrutinee: OpId, cases: Vec<SwitchCase>, catch_all_bb: BlockId },
64    GenericParam(GenericParamId),
65    /// Only valid at the beginning of a function, right after the void instruction
66    // TODO: Get rid of the type here! It is no longer required because instruction types are now stored on each Op
67    Parameter(Type),
68}
69
70impl Instr {
71    pub fn replace_bb(&mut self, old: BlockId, new: BlockId) {
72        fn replace(target: &mut BlockId, old: BlockId, new: BlockId) {
73            if *target == old {
74                *target = new;
75            }
76        }
77        match self {
78            Instr::Br(bb) => replace(bb, old, new),
79            Instr::CondBr { true_bb, false_bb, .. } => {
80                replace(true_bb, old, new);
81                replace(false_bb, old, new);
82            },
83            Instr::SwitchBr { cases, catch_all_bb, .. } => {
84                for case in cases {
85                    replace(&mut case.bb, old, new);
86                }
87                replace(catch_all_bb, old, new);
88            },
89            _ => {}
90        }
91    }
92
93    // TODO: allocating a Vec here sucks!
94    pub fn referenced_values(&self) -> Vec<OpId> {
95        match *self {
96            Instr::Void | Instr::Const(_) | Instr::Alloca(_) | Instr::AddressOfStatic(_) | Instr::Br(_)
97                | Instr::GenericParam(_) | Instr::Parameter(_) | Instr::FunctionRef { .. } | Instr::Invalid => vec![],
98            Instr::LogicalNot(op) | Instr::Reinterpret(op, _) | Instr::Truncate(op, _) | Instr::SignExtend(op, _)
99                | Instr::ZeroExtend(op, _) | Instr::FloatCast(op, _) | Instr::FloatToInt(op, _)
100                | Instr::IntToFloat(op, _) | Instr::Load(op) | Instr::Pointer { op, .. }
101                | Instr::DirectFieldAccess { val: op, .. } | Instr::IndirectFieldAccess { val: op, .. }
102                | Instr::DiscriminantAccess { val: op } | Instr::Ret(op) | Instr::CondBr { condition: op, .. }
103                | Instr::SwitchBr { scrutinee: op, .. } | Instr::Variant { payload: op, .. }
104                | Instr::InternalFieldAccess { val: op, .. } | Instr::Import(op) => vec![op],
105            Instr::Store { location, value } => vec![location, value],
106            Instr::Call { arguments: ref ops, .. } | Instr::ExternCall { arguments: ref ops, .. }
107                | Instr::Intrinsic { arguments: ref ops, .. } | Instr::Struct { fields: ref ops, .. }
108                | Instr::Enum { variants: ref ops, .. } | Instr::StructLit { fields: ref ops, .. } => ops.iter().copied().collect(),
109            Instr::FunctionTy { ref param_tys, ret_ty } => param_tys.iter().copied().chain(std::iter::once(ret_ty)).collect(),
110        }
111    }
112
113    pub fn references_value(&self, val: OpId) -> bool {
114        self.referenced_values().iter().any(|&referenced| referenced == val)
115    }
116
117    pub fn replace_value(&mut self, old: OpId, new: OpId) {
118        fn replace(target: &mut OpId, old: OpId, new: OpId) {
119            if *target == old {
120                *target = new;
121            }
122        }
123        match self {
124            Instr::Void | Instr::Const(_) | Instr::Alloca(_) | Instr::AddressOfStatic(_) | Instr::Br(_)
125                | Instr::GenericParam(_) | Instr::Parameter(_) | Instr::FunctionRef { .. } | Instr::Invalid => {},
126            Instr::LogicalNot(op) | Instr::Reinterpret(op, _) | Instr::Truncate(op, _) | Instr::SignExtend(op, _)
127                | Instr::ZeroExtend(op, _) | Instr::FloatCast(op, _) | Instr::FloatToInt(op, _)
128                | Instr::IntToFloat(op, _) | Instr::Load(op) | Instr::Pointer { op, .. }
129                | Instr::DirectFieldAccess { val: op, .. } | Instr::IndirectFieldAccess { val: op, .. }
130                | Instr::DiscriminantAccess { val: op } | Instr::Ret(op) | Instr::CondBr { condition: op, .. }
131                | Instr::SwitchBr { scrutinee: op, .. } | Instr::Variant { payload: op, .. }
132                | Instr::InternalFieldAccess { val: op, .. } | Instr::Import(op) => replace(op, old, new),
133            Instr::Store { location, value } => {
134                replace(location, old, new);
135                replace(value, old, new);
136            },
137            Instr::Call { arguments: ref mut ops, .. } | Instr::ExternCall { arguments: ref mut ops, .. }
138                | Instr::Intrinsic { arguments: ref mut ops, .. } | Instr::Struct { fields: ref mut ops, .. }
139                | Instr::Enum { variants: ref mut ops, .. } | Instr::StructLit { fields: ref mut ops, .. } => {
140                    for op in ops {
141                        replace(op, old, new);
142                    }
143                }
144            Instr::FunctionTy { ref mut param_tys, ref mut ret_ty } => {
145                for op in param_tys {
146                    replace(op, old, new);
147                }
148                replace(ret_ty, old, new);
149            }
150        }
151    }
152}
153
154#[derive(Clone, Debug, PartialEq)]
155pub enum Const {
156    Int { lit: BigInt, ty: Type },
157    Float { lit: f64, ty: Type },
158    Str { id: StrId, ty: Type },
159    /// A compile-time known string that comes from a string literal. This will be used in the future to convert
160    /// to some user-defined type at compile-time.
161    StrLit(CString),
162    Bool(bool),
163    Ty(Type),
164    Mod(ModScopeId),
165    BasicVariant { enuum: EnumId, index: usize },
166    StructLit { fields: Vec<Const>, id: StructId },
167    Void,
168}
169
170impl Const {
171    pub fn ty(&self) -> Type {
172        match self {
173            Const::Int { ty, .. } | Const::Float { ty, .. } | Const::Str { ty, .. } => ty.clone(),
174            Const::StrLit(_) => Type::Internal(InternalType::StringLiteral),
175            Const::Bool(_) => Type::Bool,
176            Const::Ty(_) => Type::Ty,
177            &Const::BasicVariant { enuum, .. } => Type::Enum(enuum),
178            Const::Mod(_) => Type::Mod,
179            &Const::StructLit { id, ref fields } => Type::Struct(
180                StructType {
181                    field_tys: fields.iter().map(|field| field.ty()).collect(),
182                    identity: id,
183                }
184            ),
185            Const::Void => Type::Void,
186        }
187    }
188}
189
190#[derive(Default, Debug, Clone)]
191pub struct InstrNamespace {
192    name_usages: HashMap<String, u16>,
193}
194
195impl InstrNamespace {
196    pub fn insert(&mut self, name: impl Into<String>) -> String {
197        let mut name = name.into();
198        let entry = self.name_usages.entry(name.clone()).or_default();
199        if *entry > 0 {
200            name = format!("{}.{}", name, *entry);
201        }
202        *entry += 1;
203        name
204    }
205}
206
207#[derive(Debug, Default, Clone)]
208pub struct Function {
209    pub name: Option<Sym>,
210    pub ty: FunctionType,
211    pub num_instrs: usize,
212    /// Index 0 is defined to be the entry block
213    pub blocks: Vec<BlockId>,
214    pub decl: Option<DeclId>,
215    // Note: Is a Vec, not a Range, because generic params might not always be contiguous in
216    // GenericParamId space
217    pub generic_params: Vec<GenericParamId>,
218    pub instr_namespace: InstrNamespace,
219    pub is_comptime: bool,
220}
221
222impl Code {
223    pub fn num_parameters(&self, func: &Function) -> usize {
224        let entry = func.blocks[0];
225        let block = &self.blocks[entry];
226        block.ops.iter()
227            .filter(|&&op| matches!(self.ops[op].as_mir_instr().unwrap(), Instr::Parameter(_)))
228            .count()
229    }
230
231    #[display_adapter]
232    pub fn display_func(&self, func: &Function, name: &str, w: &mut Formatter) {
233        writeln!(w, "fn {}() {{", name)?;
234        for &block in &func.blocks {
235            write!(w, "%bb{}:\n{}", block.index(), self.display_block(block))?;
236        }
237        writeln!(w, "}}")?;
238        Ok(())
239    }
240}
241
242#[derive(Clone)]
243pub struct StructLayout {
244    pub field_offsets: SmallVec<[usize; 2]>,
245    pub alignment: usize,
246    pub size: usize,
247    pub stride: usize,
248}
249
250#[derive(Clone)]
251pub struct EnumLayout {
252    pub payload_offsets: SmallVec<[usize; 2]>,
253    pub alignment: usize,
254    pub size: usize,
255    pub stride: usize,
256}
257
258#[derive(Debug)]
259pub enum BlockState {
260    Created,
261    Started,
262    Ended,
263}
264
265pub struct Static {
266    pub name: String,
267    pub val: Const,
268}
269
270pub struct ExternMod {
271    pub library_path: CString,
272    pub imported_functions: Vec<ExternFunction>,
273}
274
275#[derive(Debug)]
276pub struct ExternFunction {
277    pub name: String,
278    pub ty: FunctionType,
279}
280
281pub struct MirCode {
282    pub strings: IndexVec<StrId, CString>,
283    pub functions: IndexVec<FuncId, Function>,
284    pub statics: IndexVec<StaticId, Static>,
285    pub extern_mods: HashMap<ExternModId, ExternMod>,
286    pub enums: HashMap<EnumId, EnumLayout>,
287    pub source_ranges: HashMap<OpId, SourceRange>,
288    pub instr_names: HashMap<OpId, String>,
289    block_states: HashMap<BlockId, BlockState>,
290}
291
292#[derive(Debug)]
293pub enum StartBlockError {
294    BlockEnded,
295}
296
297#[derive(Debug)]
298pub enum EndBlockError {
299    BlockEnded,
300    BlockNotStarted,
301}
302
303impl MirCode {
304    pub fn new() -> Self {
305        MirCode {
306            strings: IndexVec::new(),
307            functions: IndexVec::new(),
308            statics: IndexVec::new(),
309            extern_mods: HashMap::new(),
310            enums: HashMap::new(),
311            source_ranges: HashMap::new(),
312            instr_names: HashMap::new(),
313            block_states: HashMap::new(),
314        }
315    }
316
317    fn get_block_state(&mut self, block: BlockId) -> &mut BlockState {
318        self.block_states.entry(block).or_insert(BlockState::Created)
319    }
320
321    pub fn start_block(&mut self, block: BlockId) -> Result<(), StartBlockError> {
322        let state = self.get_block_state(block);
323        match state {
324            BlockState::Created => {
325                *state = BlockState::Started;
326                Ok(())
327            }
328            BlockState::Started => Ok(()),
329            BlockState::Ended => Err(StartBlockError::BlockEnded),
330        }
331    }
332
333    pub fn end_block(&mut self, block: BlockId) -> Result<(), EndBlockError> {
334        let state = self.get_block_state(block);
335        match state {
336            BlockState::Created => Err(EndBlockError::BlockNotStarted),
337            BlockState::Started => {
338                *state = BlockState::Ended;
339                Ok(())
340            },
341            BlockState::Ended => Err(EndBlockError::BlockEnded),
342        }
343    }
344
345    pub fn first_unended_block(&self, func: &Function) -> Option<BlockId> {
346        func.blocks.iter().find(|&block| {
347            let state = &self.block_states[block];
348            !matches!(state, BlockState::Ended)
349        }).copied()
350    }
351
352    pub fn check_all_blocks_ended(&self, func: &Function) {
353        if let Some(block) = self.first_unended_block(func) {
354            panic!("MIR: Block {} was not ended", block.index());
355        }
356    }
357}
358
359impl Default for MirCode {
360    fn default() -> Self { Self::new() }
361}