duskphantom_middle/transform/
constant_fold.rs1use anyhow::Result;
18
19use crate::ir::instruction::downcast_ref;
20use crate::ir::instruction::misc_inst::FCmp;
21use crate::ir::instruction::misc_inst::FCmpOp;
22use crate::ir::instruction::misc_inst::ICmp;
23use crate::ir::instruction::misc_inst::ICmpOp;
24use crate::{
25 ir::{instruction::InstType, Constant, InstPtr, Operand},
26 Program,
27};
28
29use super::Transform;
30
31#[allow(unused)]
32pub fn optimize_program(program: &mut Program) -> Result<bool> {
33 ConstantFold::new(program).run_and_log()
34}
35
36pub struct ConstantFold<'a> {
37 program: &'a mut Program,
38}
39
40impl<'a> Transform for ConstantFold<'a> {
41 fn get_program_mut(&mut self) -> &mut Program {
42 self.program
43 }
44
45 fn name() -> String {
46 "constant_fold".to_string()
47 }
48
49 fn run(&mut self) -> Result<bool> {
50 let mut changed = false;
51 for func in self.program.module.functions.clone().iter() {
52 if func.is_lib() {
53 continue;
54 }
55 for bb in func.dfs_iter() {
56 for inst in bb.iter() {
57 changed |= self.constant_fold_inst(inst)?;
58 }
59 }
60 }
61 Ok(changed)
62 }
63}
64
65impl<'a> ConstantFold<'a> {
66 pub fn new(program: &'a mut Program) -> Self {
67 Self { program }
68 }
69
70 fn constant_fold_inst(&mut self, mut inst: InstPtr) -> Result<bool> {
71 match inst.get_type() {
72 InstType::Add | InstType::FAdd => {
73 let lhs = inst.get_operand()[0].clone();
74 let rhs = inst.get_operand()[1].clone();
75 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
76 let result = lhs + rhs;
77 inst.replace_self(&result.into());
78 return Ok(true);
79 }
80 }
81 InstType::Sub | InstType::FSub => {
82 let lhs = inst.get_operand()[0].clone();
83 let rhs = inst.get_operand()[1].clone();
84 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
85 let result = lhs - rhs;
86 inst.replace_self(&result.into());
87 return Ok(true);
88 }
89 }
90 InstType::Mul | InstType::FMul => {
91 let lhs = inst.get_operand()[0].clone();
92 let rhs = inst.get_operand()[1].clone();
93 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
94 let result = lhs * rhs;
95 inst.replace_self(&result.into());
96 return Ok(true);
97 }
98 }
99 InstType::UDiv => {
100 let lhs = inst.get_operand()[0].clone();
101 let rhs = inst.get_operand()[1].clone();
102 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
103 let lhs: u32 = lhs.into();
104 let rhs: u32 = rhs.into();
105 let result = lhs / rhs;
106 inst.replace_self(&Operand::Constant(result.into()));
107 return Ok(true);
108 }
109 }
110 InstType::SDiv | InstType::FDiv => {
111 let lhs = inst.get_operand()[0].clone();
112 let rhs = inst.get_operand()[1].clone();
113 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
114 let result = lhs / rhs;
115 inst.replace_self(&result.into());
116 return Ok(true);
117 }
118 }
119 InstType::URem | InstType::SRem => {
120 let lhs = inst.get_operand()[0].clone();
121 let rhs = inst.get_operand()[1].clone();
122 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
123 let result = lhs % rhs;
124 inst.replace_self(&result.into());
125 return Ok(true);
126 }
127 }
128 InstType::Shl => {
129 let lhs = inst.get_operand()[0].clone();
130 let rhs = inst.get_operand()[1].clone();
131 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
132 let result = lhs << rhs;
133 inst.replace_self(&result.into());
134 return Ok(true);
135 }
136 }
137 InstType::AShr => {
138 let lhs = inst.get_operand()[0].clone();
139 let rhs = inst.get_operand()[1].clone();
140 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
141 let result = lhs >> rhs;
142 inst.replace_self(&result.into());
143 return Ok(true);
144 }
145 }
146 InstType::And => {
147 let lhs = inst.get_operand()[0].clone();
148 let rhs = inst.get_operand()[1].clone();
149 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
150 let result = lhs & rhs;
151 inst.replace_self(&result.into());
152 return Ok(true);
153 }
154 }
155 InstType::Or => {
156 let lhs = inst.get_operand()[0].clone();
157 let rhs = inst.get_operand()[1].clone();
158 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
159 let result = lhs | rhs;
160 inst.replace_self(&result.into());
161 return Ok(true);
162 }
163 }
164 InstType::Xor => {
165 let lhs = inst.get_operand()[0].clone();
166 let rhs = inst.get_operand()[1].clone();
167 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
168 let result = lhs ^ rhs;
169 inst.replace_self(&result.into());
170 return Ok(true);
171 }
172 }
173 InstType::ZextTo | InstType::ItoFp | InstType::FpToI => {
174 let src = inst.get_operand()[0].clone();
175 if let Operand::Constant(src) = src {
176 let result = src.cast(&inst.get_value_type());
177 inst.replace_self(&result.into());
178 return Ok(true);
179 }
180 }
181 InstType::SextTo => {
182 let src = inst.get_operand()[0].clone();
183 if let Operand::Constant(Constant::Bool(b)) = src {
184 let result = if b { -1 } else { 0 };
185 inst.replace_self(&Operand::Constant(result.into()));
186 return Ok(true);
187 }
188 }
189 InstType::ICmp => {
190 let lhs = inst.get_operand()[0].clone();
191 let rhs = inst.get_operand()[1].clone();
192 let cmp_inst = downcast_ref::<ICmp>(inst.as_ref().as_ref());
193 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
194 let result = match cmp_inst.op {
195 ICmpOp::Eq => lhs == rhs,
196 ICmpOp::Ne => lhs != rhs,
197 ICmpOp::Slt => lhs < rhs,
198 ICmpOp::Sle => lhs <= rhs,
199 ICmpOp::Sgt => lhs > rhs,
200 ICmpOp::Sge => lhs >= rhs,
201 ICmpOp::Ult => {
202 let lhs: u32 = lhs.into();
203 let rhs: u32 = rhs.into();
204 lhs < rhs
205 }
206 ICmpOp::Ule => {
207 let lhs: u32 = lhs.into();
208 let rhs: u32 = rhs.into();
209 lhs <= rhs
210 }
211 ICmpOp::Ugt => {
212 let lhs: u32 = lhs.into();
213 let rhs: u32 = rhs.into();
214 lhs > rhs
215 }
216 ICmpOp::Uge => {
217 let lhs: u32 = lhs.into();
218 let rhs: u32 = rhs.into();
219 lhs >= rhs
220 }
221 };
222 inst.replace_self(&Operand::Constant(result.into()));
223 return Ok(true);
224 }
225 }
226 InstType::FCmp => {
227 let lhs = inst.get_operand()[0].clone();
228 let rhs = inst.get_operand()[1].clone();
229 let cmp_inst = downcast_ref::<FCmp>(inst.as_ref().as_ref());
230 if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
231 let result = match cmp_inst.op {
232 FCmpOp::False => false,
233 FCmpOp::True => true,
234 FCmpOp::Oeq => lhs == rhs,
235 FCmpOp::One => lhs != rhs,
236 FCmpOp::Olt => lhs < rhs,
237 FCmpOp::Ole => lhs <= rhs,
238 FCmpOp::Ogt => lhs > rhs,
239 FCmpOp::Oge => lhs >= rhs,
240 FCmpOp::Ueq => {
241 let lhs: f32 = lhs.into();
242 let rhs: f32 = rhs.into();
243 lhs == rhs || (lhs.is_nan() && rhs.is_nan())
244 }
245 FCmpOp::Une => {
246 let lhs: f32 = lhs.into();
247 let rhs: f32 = rhs.into();
248 lhs.is_nan() || rhs.is_nan() || lhs != rhs
249 }
250 FCmpOp::Ult => {
251 let lhs: f32 = lhs.into();
252 let rhs: f32 = rhs.into();
253 lhs < rhs || (lhs.is_nan() && !rhs.is_nan())
254 }
255 FCmpOp::Ule => {
256 let lhs: f32 = lhs.into();
257 let rhs: f32 = rhs.into();
258 lhs <= rhs || (lhs.is_nan() && !rhs.is_nan())
259 }
260 FCmpOp::Ugt => {
261 let lhs: f32 = lhs.into();
262 let rhs: f32 = rhs.into();
263 lhs > rhs || (!lhs.is_nan() && rhs.is_nan())
264 }
265 FCmpOp::Uge => {
266 let lhs: f32 = lhs.into();
267 let rhs: f32 = rhs.into();
268 lhs >= rhs || (!lhs.is_nan() && rhs.is_nan())
269 }
270 _ => todo!(),
271 };
272 inst.replace_self(&Operand::Constant(result.into()));
273 return Ok(true);
274 }
275 }
276 _ => (),
277 }
278 Ok(false)
279 }
280}