// dpc/passes/opt/multifold/assign.rs

1use rustc_hash::FxHashMap;
2
3use crate::common::condition::Condition;
4use crate::common::op::Operation;
5use crate::common::reg::GetUsedRegs;
6use crate::common::ty::{DataTypeContents, ScoreTypeContents};
7use crate::common::DeclareBinding;
8use crate::common::{val::MutableValue, val::Value, Identifier};
9use crate::mir::{MIRBlock, MIRInstrKind, MIRInstruction};
10use crate::passes::{MIRPass, MIRPassData, Pass};
11use crate::project::{OptimizationLevel, ProjectSettings};
12use crate::util::{remove_indices, HashSetEmptyTracker, Only};
13
14pub struct MultifoldAssignPass;
15
16impl Pass for MultifoldAssignPass {
17	fn get_name(&self) -> &'static str {
18		"multifold_assign"
19	}
20
21	fn should_run(&self, proj: &ProjectSettings) -> bool {
22		proj.op_level >= OptimizationLevel::More
23	}
24}
25
26impl MIRPass for MultifoldAssignPass {
27	fn run_pass(&mut self, data: &mut MIRPassData) -> anyhow::Result<()> {
28		for func in data.mir.functions.values_mut() {
29			let block = &mut func.block;
30
31			let mut removed = HashSetEmptyTracker::new();
32			let mut replaced = Vec::new();
33			loop {
34				let run_again = run_iter(block, &mut removed, &mut replaced);
35				if !run_again {
36					break;
37				}
38			}
39			remove_indices(&mut block.contents, &removed);
40		}
41
42		Ok(())
43	}
44}
45
46fn run_iter(
47	block: &mut MIRBlock,
48	removed: &mut HashSetEmptyTracker<usize>,
49	replaced: &mut Vec<(usize, Vec<MIRInstruction>)>,
50) -> bool {
51	let _ = replaced;
52	let mut run_again = false;
53	let mut if_cond_assign = FxHashMap::<Identifier, IfCondAssign>::default();
54	let mut assign_const_add = FxHashMap::<Identifier, AssignConstAdd>::default();
55	let mut overwrite_op = FxHashMap::<Identifier, OverwriteOp>::default();
56	let mut stack_peak = FxHashMap::<Identifier, StackPeak>::default();
57
58	#[derive(Default)]
59	struct RegsToKeep {
60		if_cond_assign: bool,
61		assign_const_add: bool,
62		overwrite_op: bool,
63		stack_peak: bool,
64	}
65
66	for (i, instr) in block.contents.iter().enumerate() {
67		// Even though this instruction hasn't actually been removed from the vec, we treat it
68		// as if it has to prevent doing the same work over and over and actually iterating indefinitely
69		if removed.contains(&i) {
70			continue;
71		}
72
73		let mut regs_to_keep = RegsToKeep::default();
74		let mut dont_create_new_overwrite_op = false;
75
76		match &instr.kind {
77			MIRInstrKind::Assign {
78				left: MutableValue::Reg(left),
79				right,
80			} => {
81				if let Some(fold) = overwrite_op.get_mut(left) {
82					if !fold.finished {
83						fold.right = Some(right.clone());
84						fold.end_pos = i;
85						fold.finished = true;
86						dont_create_new_overwrite_op = true;
87					}
88				}
89
90				if let DeclareBinding::Value(Value::Mutable(MutableValue::Reg(right))) = right {
91					if let Some(fold) = stack_peak.get_mut(right) {
92						if &fold.original_reg == left {
93							if !fold.finished {
94								fold.end_pos = i;
95								fold.finished = true;
96								regs_to_keep.stack_peak = true;
97							}
98						} else {
99							fold.finished = true;
100						}
101					} else {
102						stack_peak.insert(
103							left.clone(),
104							StackPeak {
105								finished: false,
106								start_pos: i,
107								end_pos: i,
108								original_reg: right.clone(),
109								op_poses: Vec::new(),
110								ops: Vec::new(),
111							},
112						);
113						regs_to_keep.stack_peak = true;
114					}
115				}
116			}
117			_ => {}
118		}
119
120		match &instr.kind {
121			MIRInstrKind::Assign {
122				left: MutableValue::Reg(left),
123				right: DeclareBinding::Value(Value::Constant(DataTypeContents::Score(val))),
124			} => {
125				let val = val.get_i32();
126				let invert = if val == 0 {
127					Some(false)
128				} else if val == 1 {
129					Some(true)
130				} else {
131					None
132				};
133				if let Some(invert) = invert {
134					if_cond_assign.insert(
135						left.clone(),
136						IfCondAssign {
137							finished: false,
138							start_pos: i,
139							end_pos: i,
140							invert,
141							condition: None,
142						},
143					);
144					regs_to_keep.if_cond_assign = true;
145				}
146
147				assign_const_add.insert(
148					left.clone(),
149					AssignConstAdd {
150						finished: false,
151						start_pos: i,
152						end_pos: i,
153						const_val: val,
154						right: None,
155					},
156				);
157
158				regs_to_keep.assign_const_add = true;
159			}
160			MIRInstrKind::Add {
161				left: MutableValue::Reg(left),
162				right: Value::Mutable(MutableValue::Reg(right)),
163			} => {
164				if let Some(fold) = assign_const_add.get_mut(left) {
165					if !fold.finished {
166						fold.end_pos = i;
167						fold.right = Some(right.clone());
168						fold.finished = true;
169						regs_to_keep.assign_const_add = true;
170					}
171				}
172			}
173			MIRInstrKind::If { condition, body } => match body.contents.only().map(|x| &x.kind) {
174				Some(MIRInstrKind::Assign {
175					left: MutableValue::Reg(left),
176					right: DeclareBinding::Value(Value::Constant(DataTypeContents::Score(right))),
177				}) => {
178					let right = right.get_i32();
179					if right == 0 || right == 1 {
180						if let Some(fold) = if_cond_assign.get_mut(left) {
181							if !fold.finished {
182								if (right == 1 && !fold.invert) || (right == 0 && fold.invert) {
183									fold.end_pos = i;
184									let mut condition = condition.clone();
185									if fold.invert {
186										condition = Condition::Not(Box::new(condition));
187									}
188									fold.condition = Some(condition);
189								}
190								fold.finished = true;
191							}
192						}
193					}
194				}
195				_ => {}
196			},
197			_ => {}
198		}
199
200		if let Some(MutableValue::Reg(left)) = instr.kind.get_op_lhs() {
201			if !dont_create_new_overwrite_op {
202				overwrite_op.insert(
203					left.clone(),
204					OverwriteOp {
205						finished: false,
206						start_pos: i,
207						end_pos: i,
208						right: None,
209					},
210				);
211				regs_to_keep.overwrite_op = true;
212			}
213
214			let mut remove_stack_peak = false;
215			// Remove stack peaks with the original reg as the lhs
216			for fold in stack_peak.values_mut() {
217				if &fold.original_reg == left {
218					remove_stack_peak = true;
219				}
220			}
221			if let Some(fold) = stack_peak.get_mut(left) {
222				if !fold.finished {
223					let op = Operation::from_instr(instr.kind.clone());
224					if let Some(op) = op {
225						// If the rhs is the same reg as the original reg, we have to invalidate
226						if let Some(Value::Mutable(MutableValue::Reg(rhs))) = op.get_rhs() {
227							if rhs == &fold.original_reg {
228								fold.finished = true;
229								// We have to totally remove it
230								remove_stack_peak = true;
231							}
232						}
233						if !fold.finished {
234							fold.op_poses.push(i);
235							fold.ops.push(op);
236							regs_to_keep.stack_peak = true;
237						}
238					}
239				}
240			}
241			if remove_stack_peak {
242				stack_peak.remove(left);
243			}
244		}
245
246		let used_regs = instr.kind.get_used_regs();
247		for reg in used_regs.into_iter() {
248			if !regs_to_keep.if_cond_assign {
249				if_cond_assign.get_mut(reg).map(|x| x.finished = true);
250			}
251			if !regs_to_keep.assign_const_add {
252				assign_const_add.get_mut(reg).map(|x| x.finished = true);
253				for fold in assign_const_add.values_mut() {
254					if let Some(right) = &fold.right {
255						if right == reg {
256							fold.finished = true;
257						}
258					}
259				}
260			}
261			if !regs_to_keep.overwrite_op {
262				overwrite_op.get_mut(reg).map(|x| x.finished = true);
263			}
264			if !regs_to_keep.stack_peak {
265				stack_peak.retain(|fold_reg, fold| {
266					if fold.finished {
267						true
268					} else {
269						fold_reg != reg && &fold.original_reg != reg
270					}
271				});
272			}
273		}
274	}
275
276	// Finish the folds
277	for (reg, fold) in if_cond_assign {
278		if let Some(condition) = fold.condition {
279			run_again = true;
280			removed.insert(fold.start_pos);
281			block
282				.contents
283				.get_mut(fold.end_pos)
284				.expect("Instr at pos does not exist")
285				.kind = MIRInstrKind::Assign {
286				left: MutableValue::Reg(reg),
287				right: DeclareBinding::Condition(condition),
288			};
289		}
290	}
291
292	for (reg, fold) in assign_const_add {
293		if let Some(right) = fold.right {
294			run_again = true;
295			block
296				.contents
297				.get_mut(fold.start_pos)
298				.expect("Instr at pos does not exist")
299				.kind = MIRInstrKind::Assign {
300				left: MutableValue::Reg(reg.clone()),
301				right: DeclareBinding::Value(Value::Mutable(MutableValue::Reg(right))),
302			};
303			block
304				.contents
305				.get_mut(fold.end_pos)
306				.expect("Instr at pos does not exist")
307				.kind = MIRInstrKind::Add {
308				left: MutableValue::Reg(reg),
309				right: Value::Constant(DataTypeContents::Score(ScoreTypeContents::Score(
310					fold.const_val,
311				))),
312			};
313		}
314	}
315
316	for (reg, fold) in overwrite_op {
317		if let Some(right) = fold.right {
318			run_again = true;
319			removed.insert(fold.start_pos);
320			block
321				.contents
322				.get_mut(fold.end_pos)
323				.expect("Instr at pos does not exist")
324				.kind = MIRInstrKind::Assign {
325				left: MutableValue::Reg(reg),
326				right,
327			};
328		}
329	}
330
331	for (_, fold) in stack_peak {
332		if fold.finished && !fold.ops.is_empty() {
333			run_again = true;
334			removed.insert(fold.start_pos);
335			removed.insert(fold.end_pos);
336			for (mut op, pos) in fold.ops.into_iter().zip(fold.op_poses) {
337				op.set_lhs(MutableValue::Reg(fold.original_reg.clone()));
338				block
339					.contents
340					.get_mut(pos)
341					.expect("Instr at pos does not exist")
342					.kind = op.to_instr();
343			}
344		}
345	}
346
347	run_again
348}
349
350/// Simplifies:
351/// let x = 0; if {condition}: x = 1
352/// to:
353/// let x = cond {condition}
354struct IfCondAssign {
355	finished: bool,
356	start_pos: usize,
357	end_pos: usize,
358	/// Whether the pattern starts with x = 0 or x = 1
359	invert: bool,
360	condition: Option<Condition>,
361}
362
363/// Simplifies:
364/// let x = A; x += y
365/// to:
366/// let x = y; x += A
367struct AssignConstAdd {
368	finished: bool,
369	start_pos: usize,
370	end_pos: usize,
371	const_val: i32,
372	right: Option<Identifier>,
373}
374
375/// Simplifies:
376/// let temp = a; a = b; b = temp
377/// to:
378/// let temp = a; swap a, b;
379/// The key for the fold is the temp register
380#[allow(dead_code)]
381struct ManualSwap {
382	finished: bool,
383	pos1: usize,
384	pos2: usize,
385	pos3: usize,
386	left: Option<Identifier>,
387	right: Option<Identifier>,
388}
389
390/// Simplifies:
391/// x o= ..; x = y
392/// to:
393/// x = y
394struct OverwriteOp {
395	finished: bool,
396	start_pos: usize,
397	end_pos: usize,
398	right: Option<DeclareBinding>,
399}
400
401/// Simplifies:
402/// x = y; x o= ..; ...; y = x
403/// to:
404/// y o= ..; ...
405#[derive(Debug)]
406struct StackPeak {
407	finished: bool,
408	start_pos: usize,
409	end_pos: usize,
410	original_reg: Identifier,
411	op_poses: Vec<usize>,
412	ops: Vec<Operation>,
413}