duskphantom_frontend/transform/
constant_fold.rs

1// Copyright 2024 Duskphantom Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// SPDX-License-Identifier: Apache-2.0
16
17use std::collections::VecDeque;
18
19use anyhow::{anyhow, Context, Result};
20
21use crate::{BinaryOp, Decl, Expr, Program, Stmt, Type, TypedIdent, UnaryOp};
22use duskphantom_utils::context;
23use duskphantom_utils::frame_map::FrameMap;
24
25use super::reshape_array::{reshape_array, reshape_const_array};
26
27pub fn optimize_program(program: &mut Program) -> Result<()> {
28    let mut env = FrameMap::new();
29    for decl in program.module.iter_mut() {
30        fold_decl(decl, &mut env, true)?;
31    }
32    Ok(())
33}
34
35/// Fold constant expression in declaration into constant.
36fn fold_decl(decl: &mut Decl, env: &mut FrameMap<String, Expr>, is_global: bool) -> Result<()> {
37    match decl {
38        Decl::Const(ty, id, expr) => {
39            // Fold type
40            *ty = get_folded_type(ty, env)?;
41
42            // Calculate folded initializer
43            let mut folded: Expr;
44            match expr {
45                Some(expr) => {
46                    // Calculate from given initializer
47                    folded = get_folded_expr(expr, env, ty)?;
48
49                    // Constant array can be malformed, reshape it
50                    if let Expr::Array(arr) = folded {
51                        folded = reshape_const_array(&mut VecDeque::from(arr), ty)?;
52                    }
53                }
54                None => {
55                    // Use default initializer
56                    folded = ty.default_initializer()?;
57                }
58            }
59
60            // Update expression to folded
61            *expr = Some(folded.clone());
62
63            // Insert folded expression to environment
64            env.insert(id.clone(), folded);
65        }
66        Decl::Var(ty, _, expr) => {
67            // Fold type
68            *ty = get_folded_type(ty, env)?;
69
70            // If variable is global, initializer should be constant
71            if is_global {
72                // Calculate folded initializer
73                let mut folded: Expr;
74                match expr {
75                    Some(expr) => {
76                        // Calculate from given initializer
77                        folded = get_folded_expr(expr, env, ty)?;
78
79                        // Constant array can be malformed, reshape it
80                        if let Expr::Array(arr) = folded {
81                            folded = reshape_const_array(&mut VecDeque::from(arr), ty)?;
82                        }
83                    }
84                    None => {
85                        // Use default initializer
86                        folded = ty.default_initializer()?;
87                    }
88                }
89
90                // Update expression to folded
91                *expr = Some(folded.clone());
92            } else {
93                // Value array can be malformed, reshape it
94                if let Some(Expr::Array(arr)) = expr {
95                    *expr = Some(reshape_array(&mut VecDeque::from(arr.clone()), ty)?);
96                }
97            }
98        }
99        Decl::Stack(vec) => {
100            for decl in vec {
101                fold_decl(decl, env, is_global)?;
102            }
103        }
104        Decl::Func(ty, _, Some(stmt)) => {
105            *ty = get_folded_type(ty, env)?;
106            fold_stmt(stmt, &mut env.branch())?;
107        }
108        _ => (),
109    }
110    Ok(())
111}
112
113/// Fold constant expression in statement into constant.
114fn fold_stmt(stmt: &mut Stmt, env: &mut FrameMap<String, Expr>) -> Result<()> {
115    match stmt {
116        Stmt::Decl(decl) => fold_decl(decl, env, false)?,
117        Stmt::Block(vec) => {
118            let mut inner_env = env.branch();
119            for stmt in vec {
120                fold_stmt(stmt, &mut inner_env)?;
121            }
122        }
123        Stmt::If(_, a, b) => {
124            fold_stmt(a, env)?;
125            fold_stmt(b, env)?;
126        }
127        Stmt::While(_, a) => fold_stmt(a, env)?,
128        Stmt::DoWhile(a, _) => fold_stmt(a, env)?,
129        _ => (),
130    }
131    Ok(())
132}
133
134impl Expr {
135    pub fn to_i32(&self) -> Result<i32> {
136        match self {
137            Expr::Int(x) => Ok(*x),
138            Expr::Float(x) => Ok(*x as i32),
139            _ => Err(anyhow!("Cannot cast to i32")),
140        }
141    }
142
143    pub fn to_f32(&self) -> Result<f32> {
144        match self {
145            Expr::Int(x) => Ok(*x as f32),
146            Expr::Float(x) => Ok(*x),
147            _ => Err(anyhow!("Cannot cast to f32")),
148        }
149    }
150}
151
152impl From<i32> for Expr {
153    fn from(i: i32) -> Self {
154        Self::Int(i)
155    }
156}
157
158impl From<f32> for Expr {
159    fn from(fl: f32) -> Self {
160        Self::Float(fl)
161    }
162}
163
164impl From<bool> for Expr {
165    fn from(b: bool) -> Self {
166        Self::Bool(b)
167    }
168}
169
170/// Fold a type to constant.
171fn get_folded_type(ty: &Type, env: &FrameMap<String, Expr>) -> Result<Type> {
172    match ty {
173        Type::Pointer(ty) => Ok(Type::Pointer(get_folded_type(ty, env)?.into())),
174        Type::Array(element_type, size) => {
175            let size = get_folded_i32(size, env)?;
176            let element_type = get_folded_type(element_type, env)?;
177            Ok(Type::Array(element_type.into(), Expr::Int(size).into()))
178        }
179        Type::Function(ret, params) => {
180            // Fold return type
181            let ret = get_folded_type(ret, env)?;
182
183            // For each param, fold the type it contains
184            let params = params
185                .iter()
186                .map(|ti| get_folded_type(&ti.ty, env).map(|ty| TypedIdent::new(ty, ti.id.clone())))
187                .collect::<Result<_>>()?;
188
189            // Reconstruct function type
190            Ok(Type::Function(ret.into(), params))
191        }
192        _ => Ok(ty.clone()),
193    }
194}
195
196/// Fold an expression to constant.
197fn get_folded_expr(expr: &Expr, env: &FrameMap<String, Expr>, expr_type: &Type) -> Result<Expr> {
198    match expr_type {
199        Type::Array(element_type, _) => {
200            let arr = get_folded_array(expr, env, element_type)?;
201            Ok(arr)
202        }
203        Type::Int => {
204            let x = get_folded_i32(expr, env)?;
205            Ok(Expr::Int(x))
206        }
207        Type::Float => {
208            let x = get_folded_f32(expr, env)?;
209            Ok(Expr::Float(x))
210        }
211        _ => Err(anyhow!("cannot fold an instance of {:?}", expr_type)).with_context(|| context!()),
212    }
213}
214
215/// Fold an indexed expression to constant.
216fn get_folded_indexed(
217    expr: &Expr,
218    env: &FrameMap<String, Expr>,
219    mut indexes: Vec<usize>,
220    expr_type: &Type,
221) -> Result<Expr> {
222    // If not indexed, fallback to regular expression fold
223    if indexes.is_empty() {
224        return get_folded_expr(expr, env, expr_type);
225    }
226
227    // Expression is indexed
228    match expr {
229        Expr::Var(id) => {
230            let Some(val) = env.get(id) else {
231                return Err(anyhow!("Variable not found"));
232            };
233
234            // Although val is already folded, we still need to handle the indexes
235            get_folded_indexed(val, env, indexes, expr_type)
236        }
237        Expr::Array(arr) => {
238            // Get index
239            let ix = indexes.pop().unwrap();
240
241            // Get default initializer if index is out of bounds
242            // This makes `int x[N] = {}; x[n]` default initializer instead of poison value
243            if ix >= arr.len() {
244                return expr_type.default_initializer();
245            }
246
247            // Index unfolded array and then fold the result, to save computation
248            get_folded_indexed(&arr[ix], env, indexes, expr_type)
249        }
250        Expr::Index(arr, ix) => {
251            let ix = get_folded_i32(ix, env)?;
252            indexes.push(ix as usize);
253            get_folded_indexed(arr, env, indexes, expr_type)
254        }
255        _ => Err(anyhow!("expr {:?} can't be indexed", expr)).with_context(|| context!()),
256    }
257}
258
259/// Fold an array to constant.
260fn get_folded_array(
261    expr: &Expr,
262    env: &FrameMap<String, Expr>,
263    element_type: &Type,
264) -> Result<Expr> {
265    match expr {
266        Expr::Var(id) => {
267            let Some(val) = env.get(id) else {
268                return Err(anyhow!("Variable not found"));
269            };
270
271            // Although val is already folded, we still need to handle the type
272            get_folded_array(val, env, element_type)
273        }
274        Expr::Array(arr) => arr
275            .iter()
276            .map(|x| get_folded_expr(x, env, element_type))
277            .collect::<Result<_>>()
278            .map(Expr::Array),
279        Expr::Index(arr, ix) => {
280            let ix = get_folded_i32(ix, env)?;
281            get_folded_indexed(
282                arr,
283                env,
284                vec![ix as usize],
285                &Type::Array(element_type.clone().into(), Expr::Int(0).into()),
286            )
287        }
288        _ => get_folded_expr(expr, env, element_type),
289    }
290}
291
292/// Fold an i32 to constant.
293fn get_folded_i32(expr: &Expr, env: &FrameMap<String, Expr>) -> Result<i32> {
294    match expr {
295        Expr::Var(id) => {
296            let Some(val) = env.get(id) else {
297                return Err(anyhow!("Variable not found"));
298            };
299
300            // Value in environment is already folded, no need to fold again
301            val.to_i32()
302        }
303        Expr::Index(arr, ix) => {
304            let ix = get_folded_i32(ix, env)?;
305            get_folded_indexed(arr, env, vec![ix as usize], &Type::Int)?.to_i32()
306        }
307        Expr::Int(x) => Ok(*x),
308        Expr::Float(x) => Ok(*x as i32),
309        Expr::Bool(x) => Ok(*x as i32),
310        Expr::Unary(op, expr) => {
311            let x = get_folded_i32(expr, env)?;
312            match op {
313                UnaryOp::Neg => Ok(-x),
314                UnaryOp::Pos => Ok(x),
315                UnaryOp::Not => Ok(if x == 0 { 1 } else { 0 }),
316            }
317        }
318        Expr::Binary(head, tail) => {
319            let mut x = get_folded_i32(head, env)?;
320            for (op, expr) in tail {
321                let y = get_folded_i32(expr, env)?;
322                match op {
323                    BinaryOp::Add => x += y,
324                    BinaryOp::Sub => x -= y,
325                    BinaryOp::Mul => x *= y,
326                    BinaryOp::Div => x /= y,
327                    BinaryOp::Mod => x %= y,
328                    BinaryOp::Shr => x >>= y,
329                    BinaryOp::Shl => x <<= y,
330                    BinaryOp::BitAnd => x &= y,
331                    BinaryOp::BitOr => x |= y,
332                    BinaryOp::BitXor => x ^= y,
333                    BinaryOp::Gt => x = if x > y { 1 } else { 0 },
334                    BinaryOp::Lt => x = if x < y { 1 } else { 0 },
335                    BinaryOp::Ge => x = if x >= y { 1 } else { 0 },
336                    BinaryOp::Le => x = if x <= y { 1 } else { 0 },
337                    BinaryOp::Eq => x = if x == y { 1 } else { 0 },
338                    BinaryOp::Ne => x = if x != y { 1 } else { 0 },
339                    BinaryOp::And => x = if x != 0 && y != 0 { 1 } else { 0 },
340                    BinaryOp::Or => x = if x != 0 || y != 0 { 1 } else { 0 },
341                };
342            }
343            Ok(x)
344        }
345        _ => Err(anyhow!("expr {:?} can't be folded to i32", expr)).with_context(|| context!()),
346    }
347}
348
349/// Fold an f32 to constant.
350fn get_folded_f32(expr: &Expr, env: &FrameMap<String, Expr>) -> Result<f32> {
351    match expr {
352        Expr::Var(id) => {
353            let Some(val) = env.get(id) else {
354                return Err(anyhow!("Variable not found"));
355            };
356            val.to_f32()
357        }
358        Expr::Index(arr, ix) => {
359            let ix = get_folded_i32(ix, env)?;
360            get_folded_indexed(arr, env, vec![ix as usize], &Type::Float)?.to_f32()
361        }
362        Expr::Int(x) => Ok(*x as f32),
363        Expr::Float(x) => Ok(*x),
364        Expr::Bool(x) => Ok(*x as i32 as f32),
365        Expr::Unary(op, expr) => {
366            let x = get_folded_f32(expr, env)?;
367            match op {
368                UnaryOp::Neg => Ok(-x),
369                UnaryOp::Pos => Ok(x),
370                UnaryOp::Not => Ok(if x == 0.0 { 1.0 } else { 0.0 }),
371            }
372        }
373        Expr::Binary(head, tail) => {
374            let mut x = get_folded_f32(head, env)?;
375            for (op, expr) in tail {
376                let y = get_folded_f32(expr, env)?;
377                match op {
378                    BinaryOp::Add => x += y,
379                    BinaryOp::Sub => x -= y,
380                    BinaryOp::Mul => x *= y,
381                    BinaryOp::Div => x /= y,
382                    BinaryOp::Mod => x %= y,
383                    BinaryOp::Shr => return Err(anyhow!("Cannot shift float")),
384                    BinaryOp::Shl => return Err(anyhow!("Cannot shift float")),
385                    BinaryOp::BitAnd => return Err(anyhow!("Cannot bitwise and float")),
386                    BinaryOp::BitOr => return Err(anyhow!("Cannot bitwise or float")),
387                    BinaryOp::BitXor => return Err(anyhow!("Cannot bitwise xor float")),
388                    BinaryOp::Gt => x = if x > y { 1.0 } else { 0.0 },
389                    BinaryOp::Lt => x = if x < y { 1.0 } else { 0.0 },
390                    BinaryOp::Ge => x = if x >= y { 1.0 } else { 0.0 },
391                    BinaryOp::Le => x = if x <= y { 1.0 } else { 0.0 },
392                    BinaryOp::Eq => x = if x == y { 1.0 } else { 0.0 },
393                    BinaryOp::Ne => x = if x != y { 1.0 } else { 0.0 },
394                    BinaryOp::And => x = if x != 0.0 && y != 0.0 { 1.0 } else { 0.0 },
395                    BinaryOp::Or => x = if x != 0.0 || y != 0.0 { 1.0 } else { 0.0 },
396                };
397            }
398            Ok(x)
399        }
400        _ => Err(anyhow!("expr {:?} can't be folded to f32", expr)).with_context(|| context!()),
401    }
402}