duskphantom_middle/transform/
inst_combine.rs1use std::collections::HashSet;
18
19use anyhow::{anyhow, Context, Result};
20
21use crate::ir::instruction::downcast_ref;
22use crate::{
23 context,
24 ir::{
25 instruction::{
26 memory_op_inst::GetElementPtr,
27 misc_inst::{FCmp, FCmpOp, ICmp, ICmpOp},
28 InstType,
29 },
30 BBPtr, Constant, FunPtr, InstPtr, Operand,
31 },
32 Program,
33};
34
35use super::Transform;
36
37pub fn optimize_program(program: &mut Program) -> Result<bool> {
38 SymbolicEval::new(program).run_and_log()
39}
40
41pub struct SymbolicEval<'a> {
42 program: &'a mut Program,
43 reachable: HashSet<BBPtr>,
44 func: FunPtr,
45}
46
47impl<'a> Transform for SymbolicEval<'a> {
48 fn get_program_mut(&mut self) -> &mut Program {
49 self.program
50 }
51
52 fn name() -> String {
53 "symbolic_eval".to_string()
54 }
55
56 fn run(&mut self) -> Result<bool> {
57 let mut changed = false;
58 for func in self.program.module.functions.clone() {
59 if func.is_lib() {
60 continue;
61 }
62 self.func = func;
63 self.reachable = self.build_reachable_set()?;
64 for bb in func.rpo_iter() {
65 if !self.reachable.contains(&bb) {
66 continue;
67 }
68 for inst in bb.iter() {
69 changed |= self.symbolic_eval(inst)?;
70 }
71 }
72 }
73 Ok(changed)
74 }
75}
76
77impl<'a> SymbolicEval<'a> {
78 pub fn new(program: &'a mut Program) -> Self {
79 let func = program.module.functions[0];
80 Self {
81 program,
82 func,
83 reachable: HashSet::new(),
84 }
85 }
86
87 fn symbolic_eval(&mut self, inst: InstPtr) -> Result<bool> {
91 let mut changed = false;
92 changed |= self.canonicalize_binary(inst)?;
93 changed |= self.constant_fold(inst)?
94 || self.canonicalize_gep(inst)?
95 || self.useless_elim(inst)?
96 || self.inst_combine(inst)?;
97 Ok(changed)
98 }
99
100 fn canonicalize_binary(&mut self, mut inst: InstPtr) -> Result<bool> {
103 let inst_type = inst.get_type();
104 match inst_type {
105 InstType::Add | InstType::Mul | InstType::FAdd | InstType::FMul => {
106 let lhs = inst.get_operand()[0].clone();
107 let rhs = inst.get_operand()[1].clone();
108 if lhs.is_const() && !rhs.is_const() {
109 unsafe {
111 let vec = inst.get_operand_mut();
112 vec.swap(0, 1);
113 return Ok(true);
114 }
115 }
116 }
117 _ => (),
118 }
119 Ok(false)
120 }
121
122 fn canonicalize_gep(&mut self, mut inst: InstPtr) -> Result<bool> {
125 let inst_type = inst.get_type();
126 if inst_type == InstType::GetElementPtr && inst.get_operand().len() == 2 {
127 let lhs = inst.get_operand()[0].clone();
128 let rhs = inst.get_operand()[1].clone();
129 if let Operand::Instruction(lhs) = lhs {
130 if lhs.get_type() == InstType::GetElementPtr
131 && lhs.get_operand().last() == Some(&Operand::Constant(Constant::Int(0)))
132 {
133 let lhs_lhs = lhs.get_operand()[0].clone();
134 let mut indexes = lhs.get_operand()[1..].to_vec();
135 *indexes.last_mut().unwrap() = rhs;
136 let gep = downcast_ref::<GetElementPtr>(lhs.as_ref().as_ref());
137 let new_inst = self.program.mem_pool.get_getelementptr(
138 gep.element_type.clone(),
139 lhs_lhs,
140 indexes,
141 );
142 inst.insert_after(new_inst);
143 inst.replace_self(&new_inst.into());
144 return Ok(true);
145 }
146 }
147 }
148 Ok(false)
149 }
150
151 fn constant_fold(&mut self, mut inst: InstPtr) -> Result<bool> {
154 let inst_type = inst.get_type();
155 match inst_type {
156 InstType::Add | InstType::FAdd => {
157 let lhs = inst.get_operand()[0].clone();
158 let rhs = inst.get_operand()[1].clone();
159 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
160 let result = lhs + rhs;
161 inst.replace_self(&result.into());
162 return Ok(true);
163 }
164 }
165 InstType::Sub | InstType::FSub => {
166 let lhs = inst.get_operand()[0].clone();
167 let rhs = inst.get_operand()[1].clone();
168 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
169 let result = lhs - rhs;
170 inst.replace_self(&result.into());
171 return Ok(true);
172 }
173 }
174 InstType::Mul | InstType::FMul => {
175 let lhs = inst.get_operand()[0].clone();
176 let rhs = inst.get_operand()[1].clone();
177 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
178 let result = lhs * rhs;
179 inst.replace_self(&result.into());
180 return Ok(true);
181 }
182 }
183 InstType::UDiv => {
184 let lhs = inst.get_operand()[0].clone();
185 let rhs = inst.get_operand()[1].clone();
186 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
187 let lhs: u32 = lhs.into();
188 let rhs: u32 = rhs.into();
189 let result = lhs / rhs;
190 inst.replace_self(&Operand::Constant(result.into()));
191 return Ok(true);
192 }
193 }
194 InstType::SDiv | InstType::FDiv => {
195 let lhs = inst.get_operand()[0].clone();
196 let rhs = inst.get_operand()[1].clone();
197 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
198 let result = lhs / rhs;
199 inst.replace_self(&result.into());
200 return Ok(true);
201 }
202 }
203 InstType::URem | InstType::SRem => {
204 let lhs = inst.get_operand()[0].clone();
205 let rhs = inst.get_operand()[1].clone();
206 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
207 let result = lhs % rhs;
208 inst.replace_self(&result.into());
209 return Ok(true);
210 }
211 }
212 InstType::Shl => {
213 let lhs = inst.get_operand()[0].clone();
214 let rhs = inst.get_operand()[1].clone();
215 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
216 let result = lhs << rhs;
217 inst.replace_self(&result.into());
218 return Ok(true);
219 }
220 }
221 InstType::AShr => {
222 let lhs = inst.get_operand()[0].clone();
223 let rhs = inst.get_operand()[1].clone();
224 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
225 let result = lhs >> rhs;
226 inst.replace_self(&result.into());
227 return Ok(true);
228 }
229 }
230 InstType::And => {
231 let lhs = inst.get_operand()[0].clone();
232 let rhs = inst.get_operand()[1].clone();
233 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
234 let result = lhs & rhs;
235 inst.replace_self(&result.into());
236 return Ok(true);
237 }
238 }
239 InstType::Or => {
240 let lhs = inst.get_operand()[0].clone();
241 let rhs = inst.get_operand()[1].clone();
242 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
243 let result = lhs | rhs;
244 inst.replace_self(&result.into());
245 return Ok(true);
246 }
247 }
248 InstType::Xor => {
249 let lhs = inst.get_operand()[0].clone();
250 let rhs = inst.get_operand()[1].clone();
251 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
252 let result = lhs ^ rhs;
253 inst.replace_self(&result.into());
254 return Ok(true);
255 }
256 }
257 InstType::ZextTo | InstType::ItoFp | InstType::FpToI => {
258 let src = inst.get_operand()[0].clone();
259 if let Operand::Constant(src) = src {
260 let result = src.cast(&inst.get_value_type());
261 inst.replace_self(&result.into());
262 return Ok(true);
263 }
264 }
265 InstType::SextTo => {
266 let src = inst.get_operand()[0].clone();
267 if let Operand::Constant(Constant::Bool(b)) = src {
268 let result = if b { -1 } else { 0 };
269 inst.replace_self(&Operand::Constant(result.into()));
270 return Ok(true);
271 }
272 }
273 InstType::ICmp => {
274 let lhs = inst.get_operand()[0].clone();
275 let rhs = inst.get_operand()[1].clone();
276 let cmp_inst = downcast_ref::<ICmp>(inst.as_ref().as_ref());
277 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
278 let result = match cmp_inst.op {
279 ICmpOp::Eq => lhs == rhs,
280 ICmpOp::Ne => lhs != rhs,
281 ICmpOp::Slt => lhs < rhs,
282 ICmpOp::Sle => lhs <= rhs,
283 ICmpOp::Sgt => lhs > rhs,
284 ICmpOp::Sge => lhs >= rhs,
285 ICmpOp::Ult => {
286 let lhs: u32 = lhs.into();
287 let rhs: u32 = rhs.into();
288 lhs < rhs
289 }
290 ICmpOp::Ule => {
291 let lhs: u32 = lhs.into();
292 let rhs: u32 = rhs.into();
293 lhs <= rhs
294 }
295 ICmpOp::Ugt => {
296 let lhs: u32 = lhs.into();
297 let rhs: u32 = rhs.into();
298 lhs > rhs
299 }
300 ICmpOp::Uge => {
301 let lhs: u32 = lhs.into();
302 let rhs: u32 = rhs.into();
303 lhs >= rhs
304 }
305 };
306 inst.replace_self(&Operand::Constant(result.into()));
307 return Ok(true);
308 }
309 }
310 InstType::FCmp => {
311 let lhs = inst.get_operand()[0].clone();
312 let rhs = inst.get_operand()[1].clone();
313 let cmp_inst = downcast_ref::<FCmp>(inst.as_ref().as_ref());
314 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
315 let result = match cmp_inst.op {
316 FCmpOp::False => false,
317 FCmpOp::True => true,
318 FCmpOp::Oeq => lhs == rhs,
319 FCmpOp::One => lhs != rhs,
320 FCmpOp::Olt => lhs < rhs,
321 FCmpOp::Ole => lhs <= rhs,
322 FCmpOp::Ogt => lhs > rhs,
323 FCmpOp::Oge => lhs >= rhs,
324 FCmpOp::Ueq => {
325 let lhs: f32 = lhs.into();
326 let rhs: f32 = rhs.into();
327 lhs == rhs || (lhs.is_nan() && rhs.is_nan())
328 }
329 FCmpOp::Une => {
330 let lhs: f32 = lhs.into();
331 let rhs: f32 = rhs.into();
332 lhs.is_nan() || rhs.is_nan() || lhs != rhs
333 }
334 FCmpOp::Ult => {
335 let lhs: f32 = lhs.into();
336 let rhs: f32 = rhs.into();
337 lhs < rhs || (lhs.is_nan() && !rhs.is_nan())
338 }
339 FCmpOp::Ule => {
340 let lhs: f32 = lhs.into();
341 let rhs: f32 = rhs.into();
342 lhs <= rhs || (lhs.is_nan() && !rhs.is_nan())
343 }
344 FCmpOp::Ugt => {
345 let lhs: f32 = lhs.into();
346 let rhs: f32 = rhs.into();
347 lhs > rhs || (!lhs.is_nan() && rhs.is_nan())
348 }
349 FCmpOp::Uge => {
350 let lhs: f32 = lhs.into();
351 let rhs: f32 = rhs.into();
352 lhs >= rhs || (!lhs.is_nan() && rhs.is_nan())
353 }
354 _ => todo!(),
355 };
356 inst.replace_self(&Operand::Constant(result.into()));
357 return Ok(true);
358 }
359 }
360 _ => (),
361 }
362 Ok(false)
363 }
364
365 fn useless_elim(&mut self, mut inst: InstPtr) -> Result<bool> {
368 let inst_type = inst.get_type();
369
370 if inst_type == InstType::Br {
373 let Some(cond) = inst.get_operand().first().cloned() else {
374 return Ok(false);
375 };
376 if let Operand::Constant(Constant::Bool(cond)) = cond {
377 let parent_bb = inst
379 .get_parent_bb()
380 .ok_or_else(|| anyhow!("{} should have parent block", inst))
381 .with_context(|| context!())?;
382 self.remove_edge(parent_bb, cond)?;
383
384 let new_inst = self.program.mem_pool.get_br(None);
386 inst.insert_after(new_inst);
387 inst.remove_self();
388 return Ok(true);
389 }
390 }
391
392 match inst_type {
394 InstType::Add | InstType::Sub => {
395 let lhs = inst.get_operand()[0].clone();
396 let rhs = inst.get_operand()[1].clone();
397 if let Operand::Constant(constant) = rhs {
398 if constant == Constant::Int(0) {
399 inst.replace_self(&lhs);
400 return Ok(true);
401 }
402 }
403 }
404 InstType::FAdd | InstType::FSub => {
405 let lhs = inst.get_operand()[0].clone();
406 let rhs = inst.get_operand()[1].clone();
407 if let Operand::Constant(constant) = rhs {
408 if constant == Constant::Float(0.0) {
409 inst.replace_self(&lhs);
410 return Ok(true);
411 }
412 }
413 }
414 InstType::Mul => {
415 let lhs = inst.get_operand()[0].clone();
416 let rhs = inst.get_operand()[1].clone();
417 if let Operand::Constant(Constant::Int(1)) = rhs {
418 inst.replace_self(&lhs);
419 return Ok(true);
420 }
421 if let Operand::Constant(Constant::Int(0)) = rhs {
422 inst.replace_self(&rhs);
423 return Ok(true);
424 }
425 }
426 InstType::SDiv | InstType::UDiv => {
427 let lhs = inst.get_operand()[0].clone();
428 let rhs = inst.get_operand()[1].clone();
429 if let Operand::Constant(Constant::Int(1)) = rhs {
430 inst.replace_self(&lhs);
431 return Ok(true);
432 }
433 if let Operand::Constant(Constant::Int(0)) = lhs {
434 inst.replace_self(&lhs);
435 return Ok(true);
436 }
437 }
438 InstType::FMul => {
439 let lhs = inst.get_operand()[0].clone();
440 let rhs = inst.get_operand()[1].clone();
441 if let Operand::Constant(Constant::Float(1.0)) = rhs {
442 inst.replace_self(&lhs);
443 return Ok(true);
444 }
445 if let Operand::Constant(Constant::Float(0.0)) = rhs {
446 inst.replace_self(&rhs);
447 return Ok(true);
448 }
449 }
450 InstType::FDiv => {
451 let lhs = inst.get_operand()[0].clone();
452 let rhs = inst.get_operand()[1].clone();
453 if let Operand::Constant(Constant::Float(1.0)) = rhs {
454 inst.replace_self(&lhs);
455 return Ok(true);
456 }
457 if let Operand::Constant(Constant::Float(0.0)) = lhs {
458 inst.replace_self(&lhs);
459 return Ok(true);
460 }
461 }
462 InstType::AShr | InstType::Shl => {
463 let lhs = inst.get_operand()[0].clone();
464 let rhs = inst.get_operand()[1].clone();
465 if let Operand::Constant(rhs) = rhs {
466 if rhs == Constant::Int(0) {
467 inst.replace_self(&lhs);
468 return Ok(true);
469 }
470 }
471 if let Operand::Constant(lhs) = lhs {
472 if lhs == Constant::Int(0) {
473 inst.replace_self(&lhs.into());
474 return Ok(true);
475 }
476 }
477 }
478 InstType::Phi => {
479 let first = inst.get_operand()[0].clone();
480 let all_same = inst.get_operand().iter().all(|op| *op == first);
481 if all_same {
482 inst.replace_self(&first);
483 return Ok(true);
484 }
485 }
486 _ => (),
487 }
488
489 match inst_type {
491 InstType::SDiv | InstType::UDiv => {
492 let lhs = inst.get_operand()[0].clone();
493 let rhs = inst.get_operand()[1].clone();
494 if lhs == rhs {
495 inst.replace_self(&Constant::Int(1).into());
496 return Ok(true);
497 }
498 }
499 InstType::FDiv => {
500 let lhs = inst.get_operand()[0].clone();
501 let rhs = inst.get_operand()[1].clone();
502 if lhs == rhs {
503 inst.replace_self(&Constant::Float(1.0).into());
504 return Ok(true);
505 }
506 }
507 InstType::Sub => {
508 let lhs = inst.get_operand()[0].clone();
509 let rhs = inst.get_operand()[1].clone();
510 if lhs == rhs {
511 inst.replace_self(&Constant::Int(0).into());
512 return Ok(true);
513 }
514 }
515 InstType::Add => {
516 let lhs = inst.get_operand()[0].clone();
517 let rhs = inst.get_operand()[1].clone();
518 if lhs == rhs {
519 let new_inst = self.program.mem_pool.get_mul(lhs, Constant::Int(2).into());
520 inst.insert_after(new_inst);
521 inst.replace_self(&new_inst.into());
522 self.symbolic_eval(new_inst)?;
523 return Ok(true);
524 }
525 }
526 _ => (),
527 }
528 Ok(false)
529 }
530
531 fn inst_combine(&mut self, mut inst: InstPtr) -> Result<bool> {
534 let inst_type = inst.get_type();
535
536 match inst_type {
538 InstType::Add | InstType::Sub => {
539 let lhs = inst.get_operand()[0].clone();
540 let rhs = inst.get_operand()[1].clone();
541
542 if let Operand::Instruction(lhs) = lhs {
544 if lhs.get_type() == InstType::Mul {
545 let lhs_lhs = lhs.get_operand()[0].clone();
546 let lhs_rhs = lhs.get_operand()[1].clone();
547
548 if lhs_lhs == rhs {
549 if let Operand::Constant(Constant::Int(lhs_rhs)) = lhs_rhs {
550 let new_rhs = if inst_type == InstType::Add {
551 lhs_rhs + 1
552 } else {
553 lhs_rhs - 1
554 };
555 let new_inst = self
556 .program
557 .mem_pool
558 .get_mul(lhs_lhs, Constant::Int(new_rhs).into());
559 inst.insert_after(new_inst);
560 inst.replace_self(&new_inst.into());
561 self.symbolic_eval(new_inst)?;
562 return Ok(true);
563 }
564 }
565 }
566 }
567 }
568 _ => (),
569 }
570
571 match inst_type {
573 InstType::Add | InstType::Sub => {
574 let lhs = inst.get_operand()[0].clone();
575 let rhs = inst.get_operand()[1].clone();
576
577 if let Operand::Constant(rhs) = rhs {
579 if let Operand::Instruction(lhs) = lhs {
580 let lhs_type = lhs.get_type();
581 if matches!(lhs_type, InstType::Add | InstType::Sub) {
582 let lhs_lhs = lhs.get_operand()[0].clone();
583 let lhs_rhs = lhs.get_operand()[1].clone();
584
585 if let Operand::Constant(lhs_rhs) = lhs_rhs {
587 let new_rhs = lhs_rhs.apply(lhs_type) + rhs.apply(inst_type);
588 let new_inst =
589 self.program.mem_pool.get_add(lhs_lhs, new_rhs.into());
590 inst.insert_after(new_inst);
591 inst.replace_self(&new_inst.into());
592 self.symbolic_eval(new_inst)?;
593 return Ok(true);
594 }
595 }
596 }
597 }
598 }
599 InstType::Mul => {
600 let lhs = inst.get_operand()[0].clone();
601 let rhs = inst.get_operand()[1].clone();
602
603 if let Operand::Constant(rhs) = rhs {
605 if let Operand::Instruction(lhs) = lhs {
606 if lhs.get_type() == InstType::Mul {
607 let lhs_lhs = lhs.get_operand()[0].clone();
608 let lhs_rhs = lhs.get_operand()[1].clone();
609
610 if let Operand::Constant(lhs_rhs) = lhs_rhs {
612 let new_rhs = lhs_rhs * rhs;
613 let new_inst =
614 self.program.mem_pool.get_mul(lhs_lhs, new_rhs.into());
615 inst.insert_after(new_inst);
616 inst.replace_self(&new_inst.into());
617 self.symbolic_eval(new_inst)?;
618 return Ok(true);
619 }
620 }
621 }
622 }
623 }
624 InstType::SDiv => {
625 let lhs = inst.get_operand()[0].clone();
626 let rhs = inst.get_operand()[1].clone();
627
628 if let Operand::Constant(Constant::Int(rhs)) = rhs {
630 if let Operand::Instruction(lhs) = lhs {
631 if lhs.get_type() == InstType::SDiv {
632 let lhs_lhs = lhs.get_operand()[0].clone();
633 let lhs_rhs = lhs.get_operand()[1].clone();
634
635 if let Operand::Constant(Constant::Int(lhs_rhs)) = lhs_rhs {
637 let (new_rhs, overflow) = lhs_rhs.overflowing_mul(rhs);
638
639 if overflow {
641 inst.replace_self(&Constant::Int(0).into());
642 return Ok(true);
643 }
644
645 let new_inst = self
647 .program
648 .mem_pool
649 .get_sdiv(lhs_lhs, Constant::Int(new_rhs).into());
650 inst.insert_after(new_inst);
651 inst.replace_self(&new_inst.into());
652 self.symbolic_eval(new_inst)?;
653 return Ok(true);
654 }
655 }
656 }
657 }
658 }
659 _ => (),
660 }
661 Ok(false)
662 }
663
664 #[allow(unused)]
667 fn merge_gep(&mut self, mut inst: InstPtr) -> Result<bool> {
668 let inst_type = inst.get_type();
669 if inst_type == InstType::GetElementPtr {
670 let ptr = inst.get_operand()[0].clone();
671 if let Operand::Instruction(ptr) = ptr {
672 if ptr.get_type() == InstType::GetElementPtr {
673 let m = ptr.get_operand().len() - 1;
677
678 let add = self
680 .program
681 .mem_pool
682 .get_add(ptr.get_operand()[m].clone(), inst.get_operand()[1].clone());
683 inst.insert_before(add);
684
685 let operands = [
687 ptr.get_operand()[1..m].to_vec(),
688 vec![add.into()],
689 inst.get_operand()[2..].to_vec(),
690 ]
691 .concat();
692
693 let gep = downcast_ref::<GetElementPtr>(ptr.as_ref().as_ref());
695 let new_inst = self.program.mem_pool.get_getelementptr(
696 gep.element_type.clone(),
697 ptr.get_operand()[0].clone(),
698 operands,
699 );
700
701 inst.insert_after(new_inst);
703 inst.replace_self(&new_inst.into());
704 self.symbolic_eval(add)?;
705 self.symbolic_eval(new_inst)?;
706 return Ok(true);
707 }
708 }
709 }
710 Ok(false)
711 }
712
713 fn build_reachable_set(&mut self) -> Result<HashSet<BBPtr>> {
715 let mut reachable = HashSet::new();
716 for bb in self.func.dfs_iter() {
717 reachable.insert(bb);
718 }
719 Ok(reachable)
720 }
721
722 fn remove_edge(&mut self, mut bb: BBPtr, cond: bool) -> Result<()> {
725 if cond {
727 bb.remove_false_bb();
728 } else {
729 bb.remove_true_bb();
730 }
731
732 let reachable = self.build_reachable_set()?;
734
735 for bb in self.reachable.iter() {
737 if !reachable.contains(bb) {
738 bb.clone().remove_self();
739 }
740 }
741
742 self.reachable = reachable;
744 Ok(())
745 }
746}