1use std::collections::HashMap;
2use std::ffi::CString;
3
4use index_vec::{IndexVec, define_index_type};
5use smallvec::SmallVec;
6use string_interner::DefaultSymbol as Sym;
7use display_adapter::display_adapter;
8use num_bigint::BigInt;
9
10use crate::hir::{Intrinsic, DeclId, StructId, EnumId, ModScopeId, ExternModId, ExternFunctionRef, GenericParamId};
11use crate::ty::{Type, InternalType, FunctionType, StructType};
12use crate::{Code, BlockId, OpId, InternalField};
13use crate::source_info::SourceRange;
14
// Strongly typed `u32`-backed indices, generated by `index_vec::define_index_type!`.
// `FuncId`, `StaticId` and `StrId` key the `functions`, `statics` and `strings` tables of
// `MirCode` respectively.
define_index_type!(pub struct FuncId = u32;);
define_index_type!(pub struct StaticId = u32;);
define_index_type!(pub struct StrId = u32;);
// NOTE(review): `InstrId` is not used anywhere in the visible part of this file; instructions
// are referred to by `OpId` instead — confirm whether it is still needed.
define_index_type!(pub struct InstrId = u32;);

/// The `OpId` with raw index 0.
///
/// NOTE(review): presumably op slot 0 is reserved for `Instr::Void` — confirm where ops are
/// allocated.
pub const VOID_INSTR: OpId = OpId::from_usize_unchecked(0);
21
/// One case of an [`Instr::SwitchBr`]: a constant paired with a target block.
#[derive(Clone, Debug, PartialEq)]
pub struct SwitchCase {
    /// The constant associated with this case.
    /// NOTE(review): presumably compared against the switch scrutinee — confirm in the backend.
    pub value: Const,
    /// The block this case points to; rewritten by [`Instr::replace_bb`].
    pub bb: BlockId,
}
27
/// A single MIR instruction.
///
/// Operands are `OpId`s that refer to other ops. `referenced_values` and `replace_value`
/// enumerate/rewrite those operands, while `replace_bb` rewrites the `BlockId` targets of the
/// branch variants.
#[derive(Clone, Debug, PartialEq)]
pub enum Instr {
    // --- Variants with no `OpId` operands (see `referenced_values`) ---
    Void,
    Invalid,
    Const(Const),
    Alloca(Type),
    LogicalNot(OpId),
    // --- Calls and call-like instructions: operands are the `arguments` list ---
    Call { arguments: SmallVec<[OpId; 2]>, generic_arguments: Vec<Type>, func: FuncId },
    ExternCall { arguments: SmallVec<[OpId; 2]>, func: ExternFunctionRef },
    // Carries no `OpId` operands; only a function id and generic arguments.
    FunctionRef { generic_arguments: Vec<Type>, func: FuncId, },
    Intrinsic { arguments: SmallVec<[OpId; 2]>, ty: Type, intr: Intrinsic },
    Import(OpId),
    // --- Conversions: one operand plus a `Type` ---
    // NOTE(review): the `Type` is presumably the destination type — confirm.
    Reinterpret(OpId, Type),
    Truncate(OpId, Type),
    SignExtend(OpId, Type),
    ZeroExtend(OpId, Type),
    FloatCast(OpId, Type),
    FloatToInt(OpId, Type),
    IntToFloat(OpId, Type),
    // --- Memory ---
    Load(OpId),
    Store { location: OpId, value: OpId },
    AddressOfStatic(StaticId),
    // --- Instructions whose operands are given as ops (`Pointer`, `Struct`, `Enum`, `FunctionTy`) ---
    // NOTE(review): these look like type-constructing instructions — confirm.
    Pointer { op: OpId, is_mut: bool },
    Struct { fields: SmallVec<[OpId; 2]>, id: StructId },
    Enum { variants: SmallVec<[OpId; 2]>, id: EnumId },
    FunctionTy { param_tys: Vec<OpId>, ret_ty: OpId },
    // --- Aggregate values and field access ---
    StructLit { fields: SmallVec<[OpId; 2]>, id: StructId },
    DirectFieldAccess { val: OpId, index: usize },
    IndirectFieldAccess { val: OpId, index: usize },
    InternalFieldAccess { val: OpId, field: InternalField },
    Variant { enuum: EnumId, index: usize, payload: OpId },
    DiscriminantAccess { val: OpId },
    // --- Control flow: `Br`, `CondBr` and `SwitchBr` are the variants `replace_bb` rewrites ---
    Ret(OpId),
    Br(BlockId),
    CondBr { condition: OpId, true_bb: BlockId, false_bb: BlockId },
    SwitchBr { scrutinee: OpId, cases: Vec<SwitchCase>, catch_all_bb: BlockId },
    // --- Parameters (no `OpId` operands) ---
    GenericParam(GenericParamId),
    // Counted by `Code::num_parameters` when it appears in a function's first block.
    Parameter(Type),
}
69
70impl Instr {
71 pub fn replace_bb(&mut self, old: BlockId, new: BlockId) {
72 fn replace(target: &mut BlockId, old: BlockId, new: BlockId) {
73 if *target == old {
74 *target = new;
75 }
76 }
77 match self {
78 Instr::Br(bb) => replace(bb, old, new),
79 Instr::CondBr { true_bb, false_bb, .. } => {
80 replace(true_bb, old, new);
81 replace(false_bb, old, new);
82 },
83 Instr::SwitchBr { cases, catch_all_bb, .. } => {
84 for case in cases {
85 replace(&mut case.bb, old, new);
86 }
87 replace(catch_all_bb, old, new);
88 },
89 _ => {}
90 }
91 }
92
93 pub fn referenced_values(&self) -> Vec<OpId> {
95 match *self {
96 Instr::Void | Instr::Const(_) | Instr::Alloca(_) | Instr::AddressOfStatic(_) | Instr::Br(_)
97 | Instr::GenericParam(_) | Instr::Parameter(_) | Instr::FunctionRef { .. } | Instr::Invalid => vec![],
98 Instr::LogicalNot(op) | Instr::Reinterpret(op, _) | Instr::Truncate(op, _) | Instr::SignExtend(op, _)
99 | Instr::ZeroExtend(op, _) | Instr::FloatCast(op, _) | Instr::FloatToInt(op, _)
100 | Instr::IntToFloat(op, _) | Instr::Load(op) | Instr::Pointer { op, .. }
101 | Instr::DirectFieldAccess { val: op, .. } | Instr::IndirectFieldAccess { val: op, .. }
102 | Instr::DiscriminantAccess { val: op } | Instr::Ret(op) | Instr::CondBr { condition: op, .. }
103 | Instr::SwitchBr { scrutinee: op, .. } | Instr::Variant { payload: op, .. }
104 | Instr::InternalFieldAccess { val: op, .. } | Instr::Import(op) => vec![op],
105 Instr::Store { location, value } => vec![location, value],
106 Instr::Call { arguments: ref ops, .. } | Instr::ExternCall { arguments: ref ops, .. }
107 | Instr::Intrinsic { arguments: ref ops, .. } | Instr::Struct { fields: ref ops, .. }
108 | Instr::Enum { variants: ref ops, .. } | Instr::StructLit { fields: ref ops, .. } => ops.iter().copied().collect(),
109 Instr::FunctionTy { ref param_tys, ret_ty } => param_tys.iter().copied().chain(std::iter::once(ret_ty)).collect(),
110 }
111 }
112
113 pub fn references_value(&self, val: OpId) -> bool {
114 self.referenced_values().iter().any(|&referenced| referenced == val)
115 }
116
117 pub fn replace_value(&mut self, old: OpId, new: OpId) {
118 fn replace(target: &mut OpId, old: OpId, new: OpId) {
119 if *target == old {
120 *target = new;
121 }
122 }
123 match self {
124 Instr::Void | Instr::Const(_) | Instr::Alloca(_) | Instr::AddressOfStatic(_) | Instr::Br(_)
125 | Instr::GenericParam(_) | Instr::Parameter(_) | Instr::FunctionRef { .. } | Instr::Invalid => {},
126 Instr::LogicalNot(op) | Instr::Reinterpret(op, _) | Instr::Truncate(op, _) | Instr::SignExtend(op, _)
127 | Instr::ZeroExtend(op, _) | Instr::FloatCast(op, _) | Instr::FloatToInt(op, _)
128 | Instr::IntToFloat(op, _) | Instr::Load(op) | Instr::Pointer { op, .. }
129 | Instr::DirectFieldAccess { val: op, .. } | Instr::IndirectFieldAccess { val: op, .. }
130 | Instr::DiscriminantAccess { val: op } | Instr::Ret(op) | Instr::CondBr { condition: op, .. }
131 | Instr::SwitchBr { scrutinee: op, .. } | Instr::Variant { payload: op, .. }
132 | Instr::InternalFieldAccess { val: op, .. } | Instr::Import(op) => replace(op, old, new),
133 Instr::Store { location, value } => {
134 replace(location, old, new);
135 replace(value, old, new);
136 },
137 Instr::Call { arguments: ref mut ops, .. } | Instr::ExternCall { arguments: ref mut ops, .. }
138 | Instr::Intrinsic { arguments: ref mut ops, .. } | Instr::Struct { fields: ref mut ops, .. }
139 | Instr::Enum { variants: ref mut ops, .. } | Instr::StructLit { fields: ref mut ops, .. } => {
140 for op in ops {
141 replace(op, old, new);
142 }
143 }
144 Instr::FunctionTy { ref mut param_tys, ref mut ret_ty } => {
145 for op in param_tys {
146 replace(op, old, new);
147 }
148 replace(ret_ty, old, new);
149 }
150 }
151 }
152}
153
/// A compile-time constant value. `Const::ty` reports the type of each variant.
#[derive(Clone, Debug, PartialEq)]
pub enum Const {
    /// Integer literal with an explicit type.
    Int { lit: BigInt, ty: Type },
    /// Floating-point literal with an explicit type.
    Float { lit: f64, ty: Type },
    /// A string identified by `StrId`, with an explicit type.
    /// NOTE(review): presumably indexes `MirCode::strings` — confirm.
    Str { id: StrId, ty: Type },
    /// An inline C string; typed as `InternalType::StringLiteral`.
    StrLit(CString),
    /// Boolean constant; typed as `Type::Bool`.
    Bool(bool),
    /// A type used as a value; typed as `Type::Ty`.
    Ty(Type),
    /// A module scope used as a value; typed as `Type::Mod`.
    Mod(ModScopeId),
    /// Variant `index` of enum `enuum`; typed as `Type::Enum(enuum)`.
    /// NOTE(review): "basic" presumably means payload-free — confirm.
    BasicVariant { enuum: EnumId, index: usize },
    /// Struct literal made of constant fields; typed as a `Type::Struct` built from the field types.
    StructLit { fields: Vec<Const>, id: StructId },
    /// The void value; typed as `Type::Void`.
    Void,
}
169
170impl Const {
171 pub fn ty(&self) -> Type {
172 match self {
173 Const::Int { ty, .. } | Const::Float { ty, .. } | Const::Str { ty, .. } => ty.clone(),
174 Const::StrLit(_) => Type::Internal(InternalType::StringLiteral),
175 Const::Bool(_) => Type::Bool,
176 Const::Ty(_) => Type::Ty,
177 &Const::BasicVariant { enuum, .. } => Type::Enum(enuum),
178 Const::Mod(_) => Type::Mod,
179 &Const::StructLit { id, ref fields } => Type::Struct(
180 StructType {
181 field_tys: fields.iter().map(|field| field.ty()).collect(),
182 identity: id,
183 }
184 ),
185 Const::Void => Type::Void,
186 }
187 }
188}
189
/// Hands out unique names by appending a numeric suffix to repeated requests.
///
/// The first request for `"x"` yields `"x"`, later requests yield `"x.1"`, `"x.2"`, and so on.
#[derive(Default, Debug, Clone)]
pub struct InstrNamespace {
    /// Every name that has been handed out, mapped to the next numeric suffix to try when that
    /// name is requested again. Stored values are always at least 1.
    name_usages: HashMap<String, u16>,
}

impl InstrNamespace {
    /// Reserves a unique name derived from `name` and returns it.
    ///
    /// If `name` has not been handed out before, it is returned unchanged. Otherwise the result
    /// is `"{name}.{n}"` for the smallest not-yet-tried `n` whose result is still free.
    ///
    /// Every returned name is itself recorded as taken, so a suffixed name such as `"x.1"`
    /// can never be handed out twice, whether it was requested literally or generated from `"x"`.
    pub fn insert(&mut self, name: impl Into<String>) -> String {
        let base = name.into();

        let mut suffix = match self.name_usages.get(&base).copied() {
            Some(next) => next,
            None => {
                // First time this exact name is seen: hand it out as-is.
                self.name_usages.insert(base.clone(), 1);
                return base;
            }
        };

        loop {
            let candidate = format!("{}.{}", base, suffix);
            suffix += 1;
            // Skip candidates that were already handed out under their literal name.
            if !self.name_usages.contains_key(&candidate) {
                self.name_usages.insert(candidate.clone(), 1);
                self.name_usages.insert(base, suffix);
                return candidate;
            }
        }
    }
}
206
/// Metadata for one MIR function.
#[derive(Debug, Default, Clone)]
pub struct Function {
    /// Interned name, if the function has one.
    pub name: Option<Sym>,
    /// The function's type.
    pub ty: FunctionType,
    /// NOTE(review): presumably the number of instructions in the function — not read or
    /// written anywhere in the visible code; confirm.
    pub num_instrs: usize,
    /// The function's blocks. `Code::num_parameters` treats `blocks[0]` as the entry block, and
    /// `Code::display_func` prints the blocks in this order.
    pub blocks: Vec<BlockId>,
    /// The HIR declaration this function came from, if any (assumption from the type — confirm).
    pub decl: Option<DeclId>,
    /// Generic parameters of the function.
    pub generic_params: Vec<GenericParamId>,
    /// Namespace used to generate unique instruction names.
    pub instr_namespace: InstrNamespace,
    /// NOTE(review): presumably marks a function that is evaluated at compile time — confirm.
    pub is_comptime: bool,
}
221
222impl Code {
223 pub fn num_parameters(&self, func: &Function) -> usize {
224 let entry = func.blocks[0];
225 let block = &self.blocks[entry];
226 block.ops.iter()
227 .filter(|&&op| matches!(self.ops[op].as_mir_instr().unwrap(), Instr::Parameter(_)))
228 .count()
229 }
230
231 #[display_adapter]
232 pub fn display_func(&self, func: &Function, name: &str, w: &mut Formatter) {
233 writeln!(w, "fn {}() {{", name)?;
234 for &block in &func.blocks {
235 write!(w, "%bb{}:\n{}", block.index(), self.display_block(block))?;
236 }
237 writeln!(w, "}}")?;
238 Ok(())
239 }
240}
241
/// Memory layout of a struct.
///
/// NOTE(review): all quantities are presumably in bytes — confirm against the code that
/// computes layouts.
#[derive(Clone)]
pub struct StructLayout {
    /// One offset per field (presumably in field declaration order — confirm).
    pub field_offsets: SmallVec<[usize; 2]>,
    pub alignment: usize,
    pub size: usize,
    /// NOTE(review): presumably `size` rounded up to `alignment` — confirm.
    pub stride: usize,
}
249
/// Memory layout of an enum; stored per `EnumId` in `MirCode::enums`.
///
/// NOTE(review): all quantities are presumably in bytes — confirm against the code that
/// computes layouts.
#[derive(Clone)]
pub struct EnumLayout {
    /// One payload offset per variant (presumably indexed by variant index — confirm).
    pub payload_offsets: SmallVec<[usize; 2]>,
    pub alignment: usize,
    pub size: usize,
    /// NOTE(review): presumably `size` rounded up to `alignment` — confirm.
    pub stride: usize,
}
257
/// Lifecycle of a block as tracked by `MirCode`.
///
/// Transitions performed by `MirCode::start_block` / `MirCode::end_block`:
/// `Created -> Started -> Ended`.
#[derive(Debug)]
pub enum BlockState {
    /// Initial state; also the state assumed for a block with no recorded entry.
    Created,
    /// `start_block` has been called; `end_block` has not.
    Started,
    /// `end_block` has been called; the block can no longer be started or ended.
    Ended,
}
264
/// A static: a name paired with its constant value. Stored in `MirCode::statics`.
pub struct Static {
    pub name: String,
    pub val: Const,
}
269
/// An external module: a library path plus the functions imported from it.
/// Stored per `ExternModId` in `MirCode::extern_mods`.
pub struct ExternMod {
    /// Path of the library, as a C string.
    /// NOTE(review): presumably handed to a dynamic loader — confirm.
    pub library_path: CString,
    /// Functions imported from this library.
    pub imported_functions: Vec<ExternFunction>,
}
274
/// A function imported from an external module: its name and its type.
#[derive(Debug)]
pub struct ExternFunction {
    /// NOTE(review): presumably the symbol name looked up in the library — confirm.
    pub name: String,
    pub ty: FunctionType,
}
280
/// MIR-level program data: tables of strings, functions, statics, extern modules and enum
/// layouts, per-op metadata, and the lifecycle state of each block.
pub struct MirCode {
    /// C strings, indexed by `StrId`.
    pub strings: IndexVec<StrId, CString>,
    /// Functions, indexed by `FuncId`.
    pub functions: IndexVec<FuncId, Function>,
    /// Statics, indexed by `StaticId`.
    pub statics: IndexVec<StaticId, Static>,
    /// External modules, keyed by `ExternModId`.
    pub extern_mods: HashMap<ExternModId, ExternMod>,
    /// Enum layouts, keyed by `EnumId`.
    pub enums: HashMap<EnumId, EnumLayout>,
    /// Source range recorded for an op, where one exists.
    pub source_ranges: HashMap<OpId, SourceRange>,
    /// Name recorded for an op, where one exists.
    pub instr_names: HashMap<OpId, String>,
    /// Lifecycle state per block. `get_block_state` inserts `BlockState::Created` for a block
    /// that has no entry yet.
    block_states: HashMap<BlockId, BlockState>,
}
291
/// Error returned by `MirCode::start_block`.
#[derive(Debug)]
pub enum StartBlockError {
    /// The block has already been ended and cannot be started again.
    BlockEnded,
}
296
/// Error returned by `MirCode::end_block`.
#[derive(Debug)]
pub enum EndBlockError {
    /// The block has already been ended.
    BlockEnded,
    /// The block was never started (it is still in the `Created` state).
    BlockNotStarted,
}
302
303impl MirCode {
304 pub fn new() -> Self {
305 MirCode {
306 strings: IndexVec::new(),
307 functions: IndexVec::new(),
308 statics: IndexVec::new(),
309 extern_mods: HashMap::new(),
310 enums: HashMap::new(),
311 source_ranges: HashMap::new(),
312 instr_names: HashMap::new(),
313 block_states: HashMap::new(),
314 }
315 }
316
317 fn get_block_state(&mut self, block: BlockId) -> &mut BlockState {
318 self.block_states.entry(block).or_insert(BlockState::Created)
319 }
320
321 pub fn start_block(&mut self, block: BlockId) -> Result<(), StartBlockError> {
322 let state = self.get_block_state(block);
323 match state {
324 BlockState::Created => {
325 *state = BlockState::Started;
326 Ok(())
327 }
328 BlockState::Started => Ok(()),
329 BlockState::Ended => Err(StartBlockError::BlockEnded),
330 }
331 }
332
333 pub fn end_block(&mut self, block: BlockId) -> Result<(), EndBlockError> {
334 let state = self.get_block_state(block);
335 match state {
336 BlockState::Created => Err(EndBlockError::BlockNotStarted),
337 BlockState::Started => {
338 *state = BlockState::Ended;
339 Ok(())
340 },
341 BlockState::Ended => Err(EndBlockError::BlockEnded),
342 }
343 }
344
345 pub fn first_unended_block(&self, func: &Function) -> Option<BlockId> {
346 func.blocks.iter().find(|&block| {
347 let state = &self.block_states[block];
348 !matches!(state, BlockState::Ended)
349 }).copied()
350 }
351
352 pub fn check_all_blocks_ended(&self, func: &Function) {
353 if let Some(block) = self.first_unended_block(func) {
354 panic!("MIR: Block {} was not ended", block.index());
355 }
356 }
357}
358
359impl Default for MirCode {
360 fn default() -> Self { Self::new() }
361}