duskphantom_middle/transform/
constant_fold.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use anyhow::Result;
18
19use crate::ir::instruction::downcast_ref;
20use crate::ir::instruction::misc_inst::FCmp;
21use crate::ir::instruction::misc_inst::FCmpOp;
22use crate::ir::instruction::misc_inst::ICmp;
23use crate::ir::instruction::misc_inst::ICmpOp;
24use crate::{
25    ir::{instruction::InstType, Constant, InstPtr, Operand},
26    Program,
27};
28
29use super::Transform;
30
31#[allow(unused)]
32pub fn optimize_program(program: &mut Program) -> Result<bool> {
33    ConstantFold::new(program).run_and_log()
34}
35
36pub struct ConstantFold<'a> {
37    program: &'a mut Program,
38}
39
40impl<'a> Transform for ConstantFold<'a> {
41    fn get_program_mut(&mut self) -> &mut Program {
42        self.program
43    }
44
45    fn name() -> String {
46        "constant_fold".to_string()
47    }
48
49    fn run(&mut self) -> Result<bool> {
50        let mut changed = false;
51        for func in self.program.module.functions.clone().iter() {
52            if func.is_lib() {
53                continue;
54            }
55            for bb in func.dfs_iter() {
56                for inst in bb.iter() {
57                    changed |= self.constant_fold_inst(inst)?;
58                }
59            }
60        }
61        Ok(changed)
62    }
63}
64
65impl<'a> ConstantFold<'a> {
66    pub fn new(program: &'a mut Program) -> Self {
67        Self { program }
68    }
69
70    fn constant_fold_inst(&mut self, mut inst: InstPtr) -> Result<bool> {
71        match inst.get_type() {
72            InstType::Add | InstType::FAdd => {
73                let lhs = inst.get_operand()[0].clone();
74                let rhs = inst.get_operand()[1].clone();
75                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
76                    let result = lhs + rhs;
77                    inst.replace_self(&result.into());
78                    return Ok(true);
79                }
80            }
81            InstType::Sub | InstType::FSub => {
82                let lhs = inst.get_operand()[0].clone();
83                let rhs = inst.get_operand()[1].clone();
84                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
85                    let result = lhs - rhs;
86                    inst.replace_self(&result.into());
87                    return Ok(true);
88                }
89            }
90            InstType::Mul | InstType::FMul => {
91                let lhs = inst.get_operand()[0].clone();
92                let rhs = inst.get_operand()[1].clone();
93                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
94                    let result = lhs * rhs;
95                    inst.replace_self(&result.into());
96                    return Ok(true);
97                }
98            }
99            InstType::UDiv => {
100                let lhs = inst.get_operand()[0].clone();
101                let rhs = inst.get_operand()[1].clone();
102                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
103                    let lhs: u32 = lhs.into();
104                    let rhs: u32 = rhs.into();
105                    let result = lhs / rhs;
106                    inst.replace_self(&Operand::Constant(result.into()));
107                    return Ok(true);
108                }
109            }
110            InstType::SDiv | InstType::FDiv => {
111                let lhs = inst.get_operand()[0].clone();
112                let rhs = inst.get_operand()[1].clone();
113                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
114                    let result = lhs / rhs;
115                    inst.replace_self(&result.into());
116                    return Ok(true);
117                }
118            }
119            InstType::URem | InstType::SRem => {
120                let lhs = inst.get_operand()[0].clone();
121                let rhs = inst.get_operand()[1].clone();
122                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
123                    let result = lhs % rhs;
124                    inst.replace_self(&result.into());
125                    return Ok(true);
126                }
127            }
128            InstType::Shl => {
129                let lhs = inst.get_operand()[0].clone();
130                let rhs = inst.get_operand()[1].clone();
131                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
132                    let result = lhs << rhs;
133                    inst.replace_self(&result.into());
134                    return Ok(true);
135                }
136            }
137            InstType::AShr => {
138                let lhs = inst.get_operand()[0].clone();
139                let rhs = inst.get_operand()[1].clone();
140                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
141                    let result = lhs >> rhs;
142                    inst.replace_self(&result.into());
143                    return Ok(true);
144                }
145            }
146            InstType::And => {
147                let lhs = inst.get_operand()[0].clone();
148                let rhs = inst.get_operand()[1].clone();
149                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
150                    let result = lhs & rhs;
151                    inst.replace_self(&result.into());
152                    return Ok(true);
153                }
154            }
155            InstType::Or => {
156                let lhs = inst.get_operand()[0].clone();
157                let rhs = inst.get_operand()[1].clone();
158                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
159                    let result = lhs | rhs;
160                    inst.replace_self(&result.into());
161                    return Ok(true);
162                }
163            }
164            InstType::Xor => {
165                let lhs = inst.get_operand()[0].clone();
166                let rhs = inst.get_operand()[1].clone();
167                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
168                    let result = lhs ^ rhs;
169                    inst.replace_self(&result.into());
170                    return Ok(true);
171                }
172            }
173            InstType::ZextTo | InstType::ItoFp | InstType::FpToI => {
174                let src = inst.get_operand()[0].clone();
175                if let Operand::Constant(src) = src {
176                    let result = src.cast(&inst.get_value_type());
177                    inst.replace_self(&result.into());
178                    return Ok(true);
179                }
180            }
181            InstType::SextTo => {
182                let src = inst.get_operand()[0].clone();
183                if let Operand::Constant(Constant::Bool(b)) = src {
184                    let result = if b { -1 } else { 0 };
185                    inst.replace_self(&Operand::Constant(result.into()));
186                    return Ok(true);
187                }
188            }
189            InstType::ICmp => {
190                let lhs = inst.get_operand()[0].clone();
191                let rhs = inst.get_operand()[1].clone();
192                let cmp_inst = downcast_ref::<ICmp>(inst.as_ref().as_ref());
193                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
194                    let result = match cmp_inst.op {
195                        ICmpOp::Eq => lhs == rhs,
196                        ICmpOp::Ne => lhs != rhs,
197                        ICmpOp::Slt => lhs < rhs,
198                        ICmpOp::Sle => lhs <= rhs,
199                        ICmpOp::Sgt => lhs > rhs,
200                        ICmpOp::Sge => lhs >= rhs,
201                        ICmpOp::Ult => {
202                            let lhs: u32 = lhs.into();
203                            let rhs: u32 = rhs.into();
204                            lhs < rhs
205                        }
206                        ICmpOp::Ule => {
207                            let lhs: u32 = lhs.into();
208                            let rhs: u32 = rhs.into();
209                            lhs <= rhs
210                        }
211                        ICmpOp::Ugt => {
212                            let lhs: u32 = lhs.into();
213                            let rhs: u32 = rhs.into();
214                            lhs > rhs
215                        }
216                        ICmpOp::Uge => {
217                            let lhs: u32 = lhs.into();
218                            let rhs: u32 = rhs.into();
219                            lhs >= rhs
220                        }
221                    };
222                    inst.replace_self(&Operand::Constant(result.into()));
223                    return Ok(true);
224                }
225            }
226            InstType::FCmp => {
227                let lhs = inst.get_operand()[0].clone();
228                let rhs = inst.get_operand()[1].clone();
229                let cmp_inst = downcast_ref::<FCmp>(inst.as_ref().as_ref());
230                if let (Operand::Constant(lhs), Operand::Constant(rhs)) = (lhs, rhs) {
231                    let result = match cmp_inst.op {
232                        FCmpOp::False => false,
233                        FCmpOp::True => true,
234                        FCmpOp::Oeq => lhs == rhs,
235                        FCmpOp::One => lhs != rhs,
236                        FCmpOp::Olt => lhs < rhs,
237                        FCmpOp::Ole => lhs <= rhs,
238                        FCmpOp::Ogt => lhs > rhs,
239                        FCmpOp::Oge => lhs >= rhs,
240                        FCmpOp::Ueq => {
241                            let lhs: f32 = lhs.into();
242                            let rhs: f32 = rhs.into();
243                            lhs == rhs || (lhs.is_nan() && rhs.is_nan())
244                        }
245                        FCmpOp::Une => {
246                            let lhs: f32 = lhs.into();
247                            let rhs: f32 = rhs.into();
248                            lhs.is_nan() || rhs.is_nan() || lhs != rhs
249                        }
250                        FCmpOp::Ult => {
251                            let lhs: f32 = lhs.into();
252                            let rhs: f32 = rhs.into();
253                            lhs < rhs || (lhs.is_nan() && !rhs.is_nan())
254                        }
255                        FCmpOp::Ule => {
256                            let lhs: f32 = lhs.into();
257                            let rhs: f32 = rhs.into();
258                            lhs <= rhs || (lhs.is_nan() && !rhs.is_nan())
259                        }
260                        FCmpOp::Ugt => {
261                            let lhs: f32 = lhs.into();
262                            let rhs: f32 = rhs.into();
263                            lhs > rhs || (!lhs.is_nan() && rhs.is_nan())
264                        }
265                        FCmpOp::Uge => {
266                            let lhs: f32 = lhs.into();
267                            let rhs: f32 = rhs.into();
268                            lhs >= rhs || (!lhs.is_nan() && rhs.is_nan())
269                        }
270                        _ => todo!(),
271                    };
272                    inst.replace_self(&Operand::Constant(result.into()));
273                    return Ok(true);
274                }
275            }
276            _ => (),
277        }
278        Ok(false)
279    }
280}