duskphantom_middle/transform/
inst_combine.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::HashSet;
18
19use anyhow::{anyhow, Context, Result};
20
21use crate::ir::instruction::downcast_ref;
22use crate::{
23    context,
24    ir::{
25        instruction::{
26            memory_op_inst::GetElementPtr,
27            misc_inst::{FCmp, FCmpOp, ICmp, ICmpOp},
28            InstType,
29        },
30        BBPtr, Constant, FunPtr, InstPtr, Operand,
31    },
32    Program,
33};
34
35use super::Transform;
36
37pub fn optimize_program(program: &mut Program) -> Result<bool> {
38    SymbolicEval::new(program).run_and_log()
39}
40
/// Pass state for symbolic evaluation: canonicalizes, constant-folds and combines
/// the instructions of one function at a time.
pub struct SymbolicEval<'a> {
    /// Program being transformed; its `mem_pool` is used to create new instructions.
    program: &'a mut Program,
    /// Basic blocks of `func` currently known to be reachable.
    reachable: HashSet<BBPtr>,
    /// Function currently being processed.
    func: FunPtr,
}
46
47impl<'a> Transform for SymbolicEval<'a> {
48    fn get_program_mut(&mut self) -> &mut Program {
49        self.program
50    }
51
52    fn name() -> String {
53        "symbolic_eval".to_string()
54    }
55
56    fn run(&mut self) -> Result<bool> {
57        let mut changed = false;
58        for func in self.program.module.functions.clone() {
59            if func.is_lib() {
60                continue;
61            }
62            self.func = func;
63            self.reachable = self.build_reachable_set()?;
64            for bb in func.rpo_iter() {
65                if !self.reachable.contains(&bb) {
66                    continue;
67                }
68                for inst in bb.iter() {
69                    changed |= self.symbolic_eval(inst)?;
70                }
71            }
72        }
73        Ok(changed)
74    }
75}
76
77impl<'a> SymbolicEval<'a> {
78    pub fn new(program: &'a mut Program) -> Self {
79        let func = program.module.functions[0];
80        Self {
81            program,
82            func,
83            reachable: HashSet::new(),
84        }
85    }
86
87    /// Symbolic evaluate instruction to its simplest form.
88    /// It guarantees to simplify any loopless program to its simplest form in one pass,
89    /// so it requires only O(loop_connectedness) calls at most.
90    fn symbolic_eval(&mut self, inst: InstPtr) -> Result<bool> {
91        let mut changed = false;
92        changed |= self.canonicalize_binary(inst)?;
93        changed |= self.constant_fold(inst)?
94            || self.canonicalize_gep(inst)?
95            || self.useless_elim(inst)?
96            || self.inst_combine(inst)?;
97        Ok(changed)
98    }
99
100    /// For commutative instructions, move constant to RHS.
101    /// This does not remove instruction.
102    fn canonicalize_binary(&mut self, mut inst: InstPtr) -> Result<bool> {
103        let inst_type = inst.get_type();
104        match inst_type {
105            InstType::Add | InstType::Mul | InstType::FAdd | InstType::FMul => {
106                let lhs = inst.get_operand()[0].clone();
107                let rhs = inst.get_operand()[1].clone();
108                if lhs.is_const() && !rhs.is_const() {
109                    // Safety: swapping operand does not change use-def chain
110                    unsafe {
111                        let vec = inst.get_operand_mut();
112                        vec.swap(0, 1);
113                        return Ok(true);
114                    }
115                }
116            }
117            _ => (),
118        }
119        Ok(false)
120    }
121
122    /// Canonicalize `gep (gep %ptr, a, 0), b` to `gep %ptr, a, b`.
123    /// If changed, original instruction is removed.
124    fn canonicalize_gep(&mut self, mut inst: InstPtr) -> Result<bool> {
125        let inst_type = inst.get_type();
126        if inst_type == InstType::GetElementPtr && inst.get_operand().len() == 2 {
127            let lhs = inst.get_operand()[0].clone();
128            let rhs = inst.get_operand()[1].clone();
129            if let Operand::Instruction(lhs) = lhs {
130                if lhs.get_type() == InstType::GetElementPtr
131                    && lhs.get_operand().last() == Some(&Operand::Constant(Constant::Int(0)))
132                {
133                    let lhs_lhs = lhs.get_operand()[0].clone();
134                    let mut indexes = lhs.get_operand()[1..].to_vec();
135                    *indexes.last_mut().unwrap() = rhs;
136                    let gep = downcast_ref::<GetElementPtr>(lhs.as_ref().as_ref());
137                    let new_inst = self.program.mem_pool.get_getelementptr(
138                        gep.element_type.clone(),
139                        lhs_lhs,
140                        indexes,
141                    );
142                    inst.insert_after(new_inst);
143                    inst.replace_self(&new_inst.into());
144                    return Ok(true);
145                }
146            }
147        }
148        Ok(false)
149    }
150
151    /// Constant folding.
152    /// If changed, original instruction is removed.
153    fn constant_fold(&mut self, mut inst: InstPtr) -> Result<bool> {
154        let inst_type = inst.get_type();
155        match inst_type {
156            InstType::Add | InstType::FAdd => {
157                let lhs = inst.get_operand()[0].clone();
158                let rhs = inst.get_operand()[1].clone();
159                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
160                    let result = lhs + rhs;
161                    inst.replace_self(&result.into());
162                    return Ok(true);
163                }
164            }
165            InstType::Sub | InstType::FSub => {
166                let lhs = inst.get_operand()[0].clone();
167                let rhs = inst.get_operand()[1].clone();
168                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
169                    let result = lhs - rhs;
170                    inst.replace_self(&result.into());
171                    return Ok(true);
172                }
173            }
174            InstType::Mul | InstType::FMul => {
175                let lhs = inst.get_operand()[0].clone();
176                let rhs = inst.get_operand()[1].clone();
177                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
178                    let result = lhs * rhs;
179                    inst.replace_self(&result.into());
180                    return Ok(true);
181                }
182            }
183            InstType::UDiv => {
184                let lhs = inst.get_operand()[0].clone();
185                let rhs = inst.get_operand()[1].clone();
186                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
187                    let lhs: u32 = lhs.into();
188                    let rhs: u32 = rhs.into();
189                    let result = lhs / rhs;
190                    inst.replace_self(&Operand::Constant(result.into()));
191                    return Ok(true);
192                }
193            }
194            InstType::SDiv | InstType::FDiv => {
195                let lhs = inst.get_operand()[0].clone();
196                let rhs = inst.get_operand()[1].clone();
197                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
198                    let result = lhs / rhs;
199                    inst.replace_self(&result.into());
200                    return Ok(true);
201                }
202            }
203            InstType::URem | InstType::SRem => {
204                let lhs = inst.get_operand()[0].clone();
205                let rhs = inst.get_operand()[1].clone();
206                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
207                    let result = lhs % rhs;
208                    inst.replace_self(&result.into());
209                    return Ok(true);
210                }
211            }
212            InstType::Shl => {
213                let lhs = inst.get_operand()[0].clone();
214                let rhs = inst.get_operand()[1].clone();
215                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
216                    let result = lhs << rhs;
217                    inst.replace_self(&result.into());
218                    return Ok(true);
219                }
220            }
221            InstType::AShr => {
222                let lhs = inst.get_operand()[0].clone();
223                let rhs = inst.get_operand()[1].clone();
224                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
225                    let result = lhs >> rhs;
226                    inst.replace_self(&result.into());
227                    return Ok(true);
228                }
229            }
230            InstType::And => {
231                let lhs = inst.get_operand()[0].clone();
232                let rhs = inst.get_operand()[1].clone();
233                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
234                    let result = lhs & rhs;
235                    inst.replace_self(&result.into());
236                    return Ok(true);
237                }
238            }
239            InstType::Or => {
240                let lhs = inst.get_operand()[0].clone();
241                let rhs = inst.get_operand()[1].clone();
242                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
243                    let result = lhs | rhs;
244                    inst.replace_self(&result.into());
245                    return Ok(true);
246                }
247            }
248            InstType::Xor => {
249                let lhs = inst.get_operand()[0].clone();
250                let rhs = inst.get_operand()[1].clone();
251                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
252                    let result = lhs ^ rhs;
253                    inst.replace_self(&result.into());
254                    return Ok(true);
255                }
256            }
257            InstType::ZextTo | InstType::ItoFp | InstType::FpToI => {
258                let src = inst.get_operand()[0].clone();
259                if let Operand::Constant(src) = src {
260                    let result = src.cast(&inst.get_value_type());
261                    inst.replace_self(&result.into());
262                    return Ok(true);
263                }
264            }
265            InstType::SextTo => {
266                let src = inst.get_operand()[0].clone();
267                if let Operand::Constant(Constant::Bool(b)) = src {
268                    let result = if b { -1 } else { 0 };
269                    inst.replace_self(&Operand::Constant(result.into()));
270                    return Ok(true);
271                }
272            }
273            InstType::ICmp => {
274                let lhs = inst.get_operand()[0].clone();
275                let rhs = inst.get_operand()[1].clone();
276                let cmp_inst = downcast_ref::<ICmp>(inst.as_ref().as_ref());
277                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
278                    let result = match cmp_inst.op {
279                        ICmpOp::Eq => lhs == rhs,
280                        ICmpOp::Ne => lhs != rhs,
281                        ICmpOp::Slt => lhs < rhs,
282                        ICmpOp::Sle => lhs <= rhs,
283                        ICmpOp::Sgt => lhs > rhs,
284                        ICmpOp::Sge => lhs >= rhs,
285                        ICmpOp::Ult => {
286                            let lhs: u32 = lhs.into();
287                            let rhs: u32 = rhs.into();
288                            lhs < rhs
289                        }
290                        ICmpOp::Ule => {
291                            let lhs: u32 = lhs.into();
292                            let rhs: u32 = rhs.into();
293                            lhs <= rhs
294                        }
295                        ICmpOp::Ugt => {
296                            let lhs: u32 = lhs.into();
297                            let rhs: u32 = rhs.into();
298                            lhs > rhs
299                        }
300                        ICmpOp::Uge => {
301                            let lhs: u32 = lhs.into();
302                            let rhs: u32 = rhs.into();
303                            lhs >= rhs
304                        }
305                    };
306                    inst.replace_self(&Operand::Constant(result.into()));
307                    return Ok(true);
308                }
309            }
310            InstType::FCmp => {
311                let lhs = inst.get_operand()[0].clone();
312                let rhs = inst.get_operand()[1].clone();
313                let cmp_inst = downcast_ref::<FCmp>(inst.as_ref().as_ref());
314                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
315                    let result = match cmp_inst.op {
316                        FCmpOp::False => false,
317                        FCmpOp::True => true,
318                        FCmpOp::Oeq => lhs == rhs,
319                        FCmpOp::One => lhs != rhs,
320                        FCmpOp::Olt => lhs < rhs,
321                        FCmpOp::Ole => lhs <= rhs,
322                        FCmpOp::Ogt => lhs > rhs,
323                        FCmpOp::Oge => lhs >= rhs,
324                        FCmpOp::Ueq => {
325                            let lhs: f32 = lhs.into();
326                            let rhs: f32 = rhs.into();
327                            lhs == rhs || (lhs.is_nan() && rhs.is_nan())
328                        }
329                        FCmpOp::Une => {
330                            let lhs: f32 = lhs.into();
331                            let rhs: f32 = rhs.into();
332                            lhs.is_nan() || rhs.is_nan() || lhs != rhs
333                        }
334                        FCmpOp::Ult => {
335                            let lhs: f32 = lhs.into();
336                            let rhs: f32 = rhs.into();
337                            lhs < rhs || (lhs.is_nan() && !rhs.is_nan())
338                        }
339                        FCmpOp::Ule => {
340                            let lhs: f32 = lhs.into();
341                            let rhs: f32 = rhs.into();
342                            lhs <= rhs || (lhs.is_nan() && !rhs.is_nan())
343                        }
344                        FCmpOp::Ugt => {
345                            let lhs: f32 = lhs.into();
346                            let rhs: f32 = rhs.into();
347                            lhs > rhs || (!lhs.is_nan() && rhs.is_nan())
348                        }
349                        FCmpOp::Uge => {
350                            let lhs: f32 = lhs.into();
351                            let rhs: f32 = rhs.into();
352                            lhs >= rhs || (!lhs.is_nan() && rhs.is_nan())
353                        }
354                        _ => todo!(),
355                    };
356                    inst.replace_self(&Operand::Constant(result.into()));
357                    return Ok(true);
358                }
359            }
360            _ => (),
361        }
362        Ok(false)
363    }
364
365    /// Useless instruction elimination.
366    /// If changed, original instruction is removed.
367    fn useless_elim(&mut self, mut inst: InstPtr) -> Result<bool> {
368        let inst_type = inst.get_type();
369
370        // We treat `br` as if-else expression, try to simplify it if condition is constant
371        // Not separating this to unreachable block elim because it increases time complexity
372        if inst_type == InstType::Br {
373            let Some(cond) = inst.get_operand().first().cloned() else {
374                return Ok(false);
375            };
376            if let Operand::Constant(Constant::Bool(cond)) = cond {
377                // Rewire basic block and prune unreachable blocks
378                let parent_bb = inst
379                    .get_parent_bb()
380                    .ok_or_else(|| anyhow!("{} should have parent block", inst))
381                    .with_context(|| context!())?;
382                self.remove_edge(parent_bb, cond)?;
383
384                // Replace instruction with unconditional jump
385                let new_inst = self.program.mem_pool.get_br(None);
386                inst.insert_after(new_inst);
387                inst.remove_self();
388                return Ok(true);
389            }
390        }
391
392        // x + 0, x - 0, x * 1, x / 1, x >> 0, x << 0, 0 / x, x * 0, phi (x, x), ...
393        match inst_type {
394            InstType::Add | InstType::Sub => {
395                let lhs = inst.get_operand()[0].clone();
396                let rhs = inst.get_operand()[1].clone();
397                if let Operand::Constant(constant) = rhs {
398                    if constant == Constant::Int(0) {
399                        inst.replace_self(&lhs);
400                        return Ok(true);
401                    }
402                }
403            }
404            InstType::FAdd | InstType::FSub => {
405                let lhs = inst.get_operand()[0].clone();
406                let rhs = inst.get_operand()[1].clone();
407                if let Operand::Constant(constant) = rhs {
408                    if constant == Constant::Float(0.0) {
409                        inst.replace_self(&lhs);
410                        return Ok(true);
411                    }
412                }
413            }
414            InstType::Mul => {
415                let lhs = inst.get_operand()[0].clone();
416                let rhs = inst.get_operand()[1].clone();
417                if let Operand::Constant(Constant::Int(1)) = rhs {
418                    inst.replace_self(&lhs);
419                    return Ok(true);
420                }
421                if let Operand::Constant(Constant::Int(0)) = rhs {
422                    inst.replace_self(&rhs);
423                    return Ok(true);
424                }
425            }
426            InstType::SDiv | InstType::UDiv => {
427                let lhs = inst.get_operand()[0].clone();
428                let rhs = inst.get_operand()[1].clone();
429                if let Operand::Constant(Constant::Int(1)) = rhs {
430                    inst.replace_self(&lhs);
431                    return Ok(true);
432                }
433                if let Operand::Constant(Constant::Int(0)) = lhs {
434                    inst.replace_self(&lhs);
435                    return Ok(true);
436                }
437            }
438            InstType::FMul => {
439                let lhs = inst.get_operand()[0].clone();
440                let rhs = inst.get_operand()[1].clone();
441                if let Operand::Constant(Constant::Float(1.0)) = rhs {
442                    inst.replace_self(&lhs);
443                    return Ok(true);
444                }
445                if let Operand::Constant(Constant::Float(0.0)) = rhs {
446                    inst.replace_self(&rhs);
447                    return Ok(true);
448                }
449            }
450            InstType::FDiv => {
451                let lhs = inst.get_operand()[0].clone();
452                let rhs = inst.get_operand()[1].clone();
453                if let Operand::Constant(Constant::Float(1.0)) = rhs {
454                    inst.replace_self(&lhs);
455                    return Ok(true);
456                }
457                if let Operand::Constant(Constant::Float(0.0)) = lhs {
458                    inst.replace_self(&lhs);
459                    return Ok(true);
460                }
461            }
462            InstType::AShr | InstType::Shl => {
463                let lhs = inst.get_operand()[0].clone();
464                let rhs = inst.get_operand()[1].clone();
465                if let Operand::Constant(rhs) = rhs {
466                    if rhs == Constant::Int(0) {
467                        inst.replace_self(&lhs);
468                        return Ok(true);
469                    }
470                }
471                if let Operand::Constant(lhs) = lhs {
472                    if lhs == Constant::Int(0) {
473                        inst.replace_self(&lhs.into());
474                        return Ok(true);
475                    }
476                }
477            }
478            InstType::Phi => {
479                let first = inst.get_operand()[0].clone();
480                let all_same = inst.get_operand().iter().all(|op| *op == first);
481                if all_same {
482                    inst.replace_self(&first);
483                    return Ok(true);
484                }
485            }
486            _ => (),
487        }
488
489        // x / x, x - x, x + x (to x * 2)
490        match inst_type {
491            InstType::SDiv | InstType::UDiv => {
492                let lhs = inst.get_operand()[0].clone();
493                let rhs = inst.get_operand()[1].clone();
494                if lhs == rhs {
495                    inst.replace_self(&Constant::Int(1).into());
496                    return Ok(true);
497                }
498            }
499            InstType::FDiv => {
500                let lhs = inst.get_operand()[0].clone();
501                let rhs = inst.get_operand()[1].clone();
502                if lhs == rhs {
503                    inst.replace_self(&Constant::Float(1.0).into());
504                    return Ok(true);
505                }
506            }
507            InstType::Sub => {
508                let lhs = inst.get_operand()[0].clone();
509                let rhs = inst.get_operand()[1].clone();
510                if lhs == rhs {
511                    inst.replace_self(&Constant::Int(0).into());
512                    return Ok(true);
513                }
514            }
515            InstType::Add => {
516                let lhs = inst.get_operand()[0].clone();
517                let rhs = inst.get_operand()[1].clone();
518                if lhs == rhs {
519                    let new_inst = self.program.mem_pool.get_mul(lhs, Constant::Int(2).into());
520                    inst.insert_after(new_inst);
521                    inst.replace_self(&new_inst.into());
522                    self.symbolic_eval(new_inst)?;
523                    return Ok(true);
524                }
525            }
526            _ => (),
527        }
528        Ok(false)
529    }
530
531    /// Combine multiple instructions into one.
532    /// If changed, original instruction is removed.
533    fn inst_combine(&mut self, mut inst: InstPtr) -> Result<bool> {
534        let inst_type = inst.get_type();
535
536        // (x * n) + x = x * (n + 1), (x * n) - x = x * (n - 1)
537        match inst_type {
538            InstType::Add | InstType::Sub => {
539                let lhs = inst.get_operand()[0].clone();
540                let rhs = inst.get_operand()[1].clone();
541
542                // Check if "lhs is mul", "rhs is same as lhs_lhs" and "lhs_rhs is int constant"
543                if let Operand::Instruction(lhs) = lhs {
544                    if lhs.get_type() == InstType::Mul {
545                        let lhs_lhs = lhs.get_operand()[0].clone();
546                        let lhs_rhs = lhs.get_operand()[1].clone();
547
548                        if lhs_lhs == rhs {
549                            if let Operand::Constant(Constant::Int(lhs_rhs)) = lhs_rhs {
550                                let new_rhs = if inst_type == InstType::Add {
551                                    lhs_rhs + 1
552                                } else {
553                                    lhs_rhs - 1
554                                };
555                                let new_inst = self
556                                    .program
557                                    .mem_pool
558                                    .get_mul(lhs_lhs, Constant::Int(new_rhs).into());
559                                inst.insert_after(new_inst);
560                                inst.replace_self(&new_inst.into());
561                                self.symbolic_eval(new_inst)?;
562                                return Ok(true);
563                            }
564                        }
565                    }
566                }
567            }
568            _ => (),
569        }
570
571        // x + 1 - 6 -> x - 5, x * 2 * 3 -> x * 6, x / 2 / 3 -> x / 6
572        match inst_type {
573            InstType::Add | InstType::Sub => {
574                let lhs = inst.get_operand()[0].clone();
575                let rhs = inst.get_operand()[1].clone();
576
577                // Check if "rhs is constant" and "lhs is add or sub"
578                if let Operand::Constant(rhs) = rhs {
579                    if let Operand::Instruction(lhs) = lhs {
580                        let lhs_type = lhs.get_type();
581                        if matches!(lhs_type, InstType::Add | InstType::Sub) {
582                            let lhs_lhs = lhs.get_operand()[0].clone();
583                            let lhs_rhs = lhs.get_operand()[1].clone();
584
585                            // Combine inst if "lhs_rhs is constant"
586                            if let Operand::Constant(lhs_rhs) = lhs_rhs {
587                                let new_rhs = lhs_rhs.apply(lhs_type) + rhs.apply(inst_type);
588                                let new_inst =
589                                    self.program.mem_pool.get_add(lhs_lhs, new_rhs.into());
590                                inst.insert_after(new_inst);
591                                inst.replace_self(&new_inst.into());
592                                self.symbolic_eval(new_inst)?;
593                                return Ok(true);
594                            }
595                        }
596                    }
597                }
598            }
599            InstType::Mul => {
600                let lhs = inst.get_operand()[0].clone();
601                let rhs = inst.get_operand()[1].clone();
602
603                // Check if "rhs is constant" and "lhs is mul"
604                if let Operand::Constant(rhs) = rhs {
605                    if let Operand::Instruction(lhs) = lhs {
606                        if lhs.get_type() == InstType::Mul {
607                            let lhs_lhs = lhs.get_operand()[0].clone();
608                            let lhs_rhs = lhs.get_operand()[1].clone();
609
610                            // Combine inst if "lhs_rhs is constant"
611                            if let Operand::Constant(lhs_rhs) = lhs_rhs {
612                                let new_rhs = lhs_rhs * rhs;
613                                let new_inst =
614                                    self.program.mem_pool.get_mul(lhs_lhs, new_rhs.into());
615                                inst.insert_after(new_inst);
616                                inst.replace_self(&new_inst.into());
617                                self.symbolic_eval(new_inst)?;
618                                return Ok(true);
619                            }
620                        }
621                    }
622                }
623            }
624            InstType::SDiv => {
625                let lhs = inst.get_operand()[0].clone();
626                let rhs = inst.get_operand()[1].clone();
627
628                // Check if "rhs is constant" and "lhs is div"
629                if let Operand::Constant(Constant::Int(rhs)) = rhs {
630                    if let Operand::Instruction(lhs) = lhs {
631                        if lhs.get_type() == InstType::SDiv {
632                            let lhs_lhs = lhs.get_operand()[0].clone();
633                            let lhs_rhs = lhs.get_operand()[1].clone();
634
635                            // Combine inst if "lhs_rhs is constant"
636                            if let Operand::Constant(Constant::Int(lhs_rhs)) = lhs_rhs {
637                                let (new_rhs, overflow) = lhs_rhs.overflowing_mul(rhs);
638
639                                // If overflow, instruction result is zero
640                                if overflow {
641                                    inst.replace_self(&Constant::Int(0).into());
642                                    return Ok(true);
643                                }
644
645                                // Otherwise, combine division factors
646                                let new_inst = self
647                                    .program
648                                    .mem_pool
649                                    .get_sdiv(lhs_lhs, Constant::Int(new_rhs).into());
650                                inst.insert_after(new_inst);
651                                inst.replace_self(&new_inst.into());
652                                self.symbolic_eval(new_inst)?;
653                                return Ok(true);
654                            }
655                        }
656                    }
657                }
658            }
659            _ => (),
660        }
661        Ok(false)
662    }
663
664    /// Merge `getelementptr` instruction.
665    /// If changed, original instruction is removed.
666    #[allow(unused)]
667    fn merge_gep(&mut self, mut inst: InstPtr) -> Result<bool> {
668        let inst_type = inst.get_type();
669        if inst_type == InstType::GetElementPtr {
670            let ptr = inst.get_operand()[0].clone();
671            if let Operand::Instruction(ptr) = ptr {
672                if ptr.get_type() == InstType::GetElementPtr {
673                    // Outer GEP: getelementptr ty1, inner, i1, ..., in
674                    // Inner GEP: getelementptr ty2, alloc, j1, ..., jm
675                    // Merged GEP: getelementptr ty2, alloc, j1, ..., jm + i1, ..., in
676                    let m = ptr.get_operand().len() - 1;
677
678                    // Create instruction for jm + i1
679                    let add = self
680                        .program
681                        .mem_pool
682                        .get_add(ptr.get_operand()[m].clone(), inst.get_operand()[1].clone());
683                    inst.insert_before(add);
684
685                    // Create a list of all operands
686                    let operands = [
687                        ptr.get_operand()[1..m].to_vec(),
688                        vec![add.into()],
689                        inst.get_operand()[2..].to_vec(),
690                    ]
691                    .concat();
692
693                    // Create new GEP instruction
694                    let gep = downcast_ref::<GetElementPtr>(ptr.as_ref().as_ref());
695                    let new_inst = self.program.mem_pool.get_getelementptr(
696                        gep.element_type.clone(),
697                        ptr.get_operand()[0].clone(),
698                        operands,
699                    );
700
701                    // Replace outer GEP with new GEP
702                    inst.insert_after(new_inst);
703                    inst.replace_self(&new_inst.into());
704                    self.symbolic_eval(add)?;
705                    self.symbolic_eval(new_inst)?;
706                    return Ok(true);
707                }
708            }
709        }
710        Ok(false)
711    }
712
713    /// Returns the set of all reachable blocks.
714    fn build_reachable_set(&mut self) -> Result<HashSet<BBPtr>> {
715        let mut reachable = HashSet::new();
716        for bb in self.func.dfs_iter() {
717            reachable.insert(bb);
718        }
719        Ok(reachable)
720    }
721
722    /// Remove an edge and remove all unreachable basic blocks.
723    /// TODO: Is there a more efficient implementation?
724    fn remove_edge(&mut self, mut bb: BBPtr, cond: bool) -> Result<()> {
725        // Remove path based on condition
726        if cond {
727            bb.remove_false_bb();
728        } else {
729            bb.remove_true_bb();
730        }
731
732        // Build new reachable set
733        let reachable = self.build_reachable_set()?;
734
735        // Remove all unreachable basic blocks from old reachable set
736        for bb in self.reachable.iter() {
737            if !reachable.contains(bb) {
738                bb.clone().remove_self();
739            }
740        }
741
742        // Update reachable set
743        self.reachable = reachable;
744        Ok(())
745    }
746}