1use std::collections::VecDeque;
18
19use anyhow::{anyhow, Context, Result};
20
21use crate::{BinaryOp, Decl, Expr, Program, Stmt, Type, TypedIdent, UnaryOp};
22use duskphantom_utils::context;
23use duskphantom_utils::frame_map::FrameMap;
24
25use super::reshape_array::{reshape_array, reshape_const_array};
26
27pub fn optimize_program(program: &mut Program) -> Result<()> {
28 let mut env = FrameMap::new();
29 for decl in program.module.iter_mut() {
30 fold_decl(decl, &mut env, true)?;
31 }
32 Ok(())
33}
34
35fn fold_decl(decl: &mut Decl, env: &mut FrameMap<String, Expr>, is_global: bool) -> Result<()> {
37 match decl {
38 Decl::Const(ty, id, expr) => {
39 *ty = get_folded_type(ty, env)?;
41
42 let mut folded: Expr;
44 match expr {
45 Some(expr) => {
46 folded = get_folded_expr(expr, env, ty)?;
48
49 if let Expr::Array(arr) = folded {
51 folded = reshape_const_array(&mut VecDeque::from(arr), ty)?;
52 }
53 }
54 None => {
55 folded = ty.default_initializer()?;
57 }
58 }
59
60 *expr = Some(folded.clone());
62
63 env.insert(id.clone(), folded);
65 }
66 Decl::Var(ty, _, expr) => {
67 *ty = get_folded_type(ty, env)?;
69
70 if is_global {
72 let mut folded: Expr;
74 match expr {
75 Some(expr) => {
76 folded = get_folded_expr(expr, env, ty)?;
78
79 if let Expr::Array(arr) = folded {
81 folded = reshape_const_array(&mut VecDeque::from(arr), ty)?;
82 }
83 }
84 None => {
85 folded = ty.default_initializer()?;
87 }
88 }
89
90 *expr = Some(folded.clone());
92 } else {
93 if let Some(Expr::Array(arr)) = expr {
95 *expr = Some(reshape_array(&mut VecDeque::from(arr.clone()), ty)?);
96 }
97 }
98 }
99 Decl::Stack(vec) => {
100 for decl in vec {
101 fold_decl(decl, env, is_global)?;
102 }
103 }
104 Decl::Func(ty, _, Some(stmt)) => {
105 *ty = get_folded_type(ty, env)?;
106 fold_stmt(stmt, &mut env.branch())?;
107 }
108 _ => (),
109 }
110 Ok(())
111}
112
113fn fold_stmt(stmt: &mut Stmt, env: &mut FrameMap<String, Expr>) -> Result<()> {
115 match stmt {
116 Stmt::Decl(decl) => fold_decl(decl, env, false)?,
117 Stmt::Block(vec) => {
118 let mut inner_env = env.branch();
119 for stmt in vec {
120 fold_stmt(stmt, &mut inner_env)?;
121 }
122 }
123 Stmt::If(_, a, b) => {
124 fold_stmt(a, env)?;
125 fold_stmt(b, env)?;
126 }
127 Stmt::While(_, a) => fold_stmt(a, env)?,
128 Stmt::DoWhile(a, _) => fold_stmt(a, env)?,
129 _ => (),
130 }
131 Ok(())
132}
133
134impl Expr {
135 pub fn to_i32(&self) -> Result<i32> {
136 match self {
137 Expr::Int(x) => Ok(*x),
138 Expr::Float(x) => Ok(*x as i32),
139 _ => Err(anyhow!("Cannot cast to i32")),
140 }
141 }
142
143 pub fn to_f32(&self) -> Result<f32> {
144 match self {
145 Expr::Int(x) => Ok(*x as f32),
146 Expr::Float(x) => Ok(*x),
147 _ => Err(anyhow!("Cannot cast to f32")),
148 }
149 }
150}
151
152impl From<i32> for Expr {
153 fn from(i: i32) -> Self {
154 Self::Int(i)
155 }
156}
157
158impl From<f32> for Expr {
159 fn from(fl: f32) -> Self {
160 Self::Float(fl)
161 }
162}
163
164impl From<bool> for Expr {
165 fn from(b: bool) -> Self {
166 Self::Bool(b)
167 }
168}
169
170fn get_folded_type(ty: &Type, env: &FrameMap<String, Expr>) -> Result<Type> {
172 match ty {
173 Type::Pointer(ty) => Ok(Type::Pointer(get_folded_type(ty, env)?.into())),
174 Type::Array(element_type, size) => {
175 let size = get_folded_i32(size, env)?;
176 let element_type = get_folded_type(element_type, env)?;
177 Ok(Type::Array(element_type.into(), Expr::Int(size).into()))
178 }
179 Type::Function(ret, params) => {
180 let ret = get_folded_type(ret, env)?;
182
183 let params = params
185 .iter()
186 .map(|ti| get_folded_type(&ti.ty, env).map(|ty| TypedIdent::new(ty, ti.id.clone())))
187 .collect::<Result<_>>()?;
188
189 Ok(Type::Function(ret.into(), params))
191 }
192 _ => Ok(ty.clone()),
193 }
194}
195
196fn get_folded_expr(expr: &Expr, env: &FrameMap<String, Expr>, expr_type: &Type) -> Result<Expr> {
198 match expr_type {
199 Type::Array(element_type, _) => {
200 let arr = get_folded_array(expr, env, element_type)?;
201 Ok(arr)
202 }
203 Type::Int => {
204 let x = get_folded_i32(expr, env)?;
205 Ok(Expr::Int(x))
206 }
207 Type::Float => {
208 let x = get_folded_f32(expr, env)?;
209 Ok(Expr::Float(x))
210 }
211 _ => Err(anyhow!("cannot fold an instance of {:?}", expr_type)).with_context(|| context!()),
212 }
213}
214
215fn get_folded_indexed(
217 expr: &Expr,
218 env: &FrameMap<String, Expr>,
219 mut indexes: Vec<usize>,
220 expr_type: &Type,
221) -> Result<Expr> {
222 if indexes.is_empty() {
224 return get_folded_expr(expr, env, expr_type);
225 }
226
227 match expr {
229 Expr::Var(id) => {
230 let Some(val) = env.get(id) else {
231 return Err(anyhow!("Variable not found"));
232 };
233
234 get_folded_indexed(val, env, indexes, expr_type)
236 }
237 Expr::Array(arr) => {
238 let ix = indexes.pop().unwrap();
240
241 if ix >= arr.len() {
244 return expr_type.default_initializer();
245 }
246
247 get_folded_indexed(&arr[ix], env, indexes, expr_type)
249 }
250 Expr::Index(arr, ix) => {
251 let ix = get_folded_i32(ix, env)?;
252 indexes.push(ix as usize);
253 get_folded_indexed(arr, env, indexes, expr_type)
254 }
255 _ => Err(anyhow!("expr {:?} can't be indexed", expr)).with_context(|| context!()),
256 }
257}
258
259fn get_folded_array(
261 expr: &Expr,
262 env: &FrameMap<String, Expr>,
263 element_type: &Type,
264) -> Result<Expr> {
265 match expr {
266 Expr::Var(id) => {
267 let Some(val) = env.get(id) else {
268 return Err(anyhow!("Variable not found"));
269 };
270
271 get_folded_array(val, env, element_type)
273 }
274 Expr::Array(arr) => arr
275 .iter()
276 .map(|x| get_folded_expr(x, env, element_type))
277 .collect::<Result<_>>()
278 .map(Expr::Array),
279 Expr::Index(arr, ix) => {
280 let ix = get_folded_i32(ix, env)?;
281 get_folded_indexed(
282 arr,
283 env,
284 vec![ix as usize],
285 &Type::Array(element_type.clone().into(), Expr::Int(0).into()),
286 )
287 }
288 _ => get_folded_expr(expr, env, element_type),
289 }
290}
291
292fn get_folded_i32(expr: &Expr, env: &FrameMap<String, Expr>) -> Result<i32> {
294 match expr {
295 Expr::Var(id) => {
296 let Some(val) = env.get(id) else {
297 return Err(anyhow!("Variable not found"));
298 };
299
300 val.to_i32()
302 }
303 Expr::Index(arr, ix) => {
304 let ix = get_folded_i32(ix, env)?;
305 get_folded_indexed(arr, env, vec![ix as usize], &Type::Int)?.to_i32()
306 }
307 Expr::Int(x) => Ok(*x),
308 Expr::Float(x) => Ok(*x as i32),
309 Expr::Bool(x) => Ok(*x as i32),
310 Expr::Unary(op, expr) => {
311 let x = get_folded_i32(expr, env)?;
312 match op {
313 UnaryOp::Neg => Ok(-x),
314 UnaryOp::Pos => Ok(x),
315 UnaryOp::Not => Ok(if x == 0 { 1 } else { 0 }),
316 }
317 }
318 Expr::Binary(head, tail) => {
319 let mut x = get_folded_i32(head, env)?;
320 for (op, expr) in tail {
321 let y = get_folded_i32(expr, env)?;
322 match op {
323 BinaryOp::Add => x += y,
324 BinaryOp::Sub => x -= y,
325 BinaryOp::Mul => x *= y,
326 BinaryOp::Div => x /= y,
327 BinaryOp::Mod => x %= y,
328 BinaryOp::Shr => x >>= y,
329 BinaryOp::Shl => x <<= y,
330 BinaryOp::BitAnd => x &= y,
331 BinaryOp::BitOr => x |= y,
332 BinaryOp::BitXor => x ^= y,
333 BinaryOp::Gt => x = if x > y { 1 } else { 0 },
334 BinaryOp::Lt => x = if x < y { 1 } else { 0 },
335 BinaryOp::Ge => x = if x >= y { 1 } else { 0 },
336 BinaryOp::Le => x = if x <= y { 1 } else { 0 },
337 BinaryOp::Eq => x = if x == y { 1 } else { 0 },
338 BinaryOp::Ne => x = if x != y { 1 } else { 0 },
339 BinaryOp::And => x = if x != 0 && y != 0 { 1 } else { 0 },
340 BinaryOp::Or => x = if x != 0 || y != 0 { 1 } else { 0 },
341 };
342 }
343 Ok(x)
344 }
345 _ => Err(anyhow!("expr {:?} can't be folded to i32", expr)).with_context(|| context!()),
346 }
347}
348
349fn get_folded_f32(expr: &Expr, env: &FrameMap<String, Expr>) -> Result<f32> {
351 match expr {
352 Expr::Var(id) => {
353 let Some(val) = env.get(id) else {
354 return Err(anyhow!("Variable not found"));
355 };
356 val.to_f32()
357 }
358 Expr::Index(arr, ix) => {
359 let ix = get_folded_i32(ix, env)?;
360 get_folded_indexed(arr, env, vec![ix as usize], &Type::Float)?.to_f32()
361 }
362 Expr::Int(x) => Ok(*x as f32),
363 Expr::Float(x) => Ok(*x),
364 Expr::Bool(x) => Ok(*x as i32 as f32),
365 Expr::Unary(op, expr) => {
366 let x = get_folded_f32(expr, env)?;
367 match op {
368 UnaryOp::Neg => Ok(-x),
369 UnaryOp::Pos => Ok(x),
370 UnaryOp::Not => Ok(if x == 0.0 { 1.0 } else { 0.0 }),
371 }
372 }
373 Expr::Binary(head, tail) => {
374 let mut x = get_folded_f32(head, env)?;
375 for (op, expr) in tail {
376 let y = get_folded_f32(expr, env)?;
377 match op {
378 BinaryOp::Add => x += y,
379 BinaryOp::Sub => x -= y,
380 BinaryOp::Mul => x *= y,
381 BinaryOp::Div => x /= y,
382 BinaryOp::Mod => x %= y,
383 BinaryOp::Shr => return Err(anyhow!("Cannot shift float")),
384 BinaryOp::Shl => return Err(anyhow!("Cannot shift float")),
385 BinaryOp::BitAnd => return Err(anyhow!("Cannot bitwise and float")),
386 BinaryOp::BitOr => return Err(anyhow!("Cannot bitwise or float")),
387 BinaryOp::BitXor => return Err(anyhow!("Cannot bitwise xor float")),
388 BinaryOp::Gt => x = if x > y { 1.0 } else { 0.0 },
389 BinaryOp::Lt => x = if x < y { 1.0 } else { 0.0 },
390 BinaryOp::Ge => x = if x >= y { 1.0 } else { 0.0 },
391 BinaryOp::Le => x = if x <= y { 1.0 } else { 0.0 },
392 BinaryOp::Eq => x = if x == y { 1.0 } else { 0.0 },
393 BinaryOp::Ne => x = if x != y { 1.0 } else { 0.0 },
394 BinaryOp::And => x = if x != 0.0 && y != 0.0 { 1.0 } else { 0.0 },
395 BinaryOp::Or => x = if x != 0.0 || y != 0.0 { 1.0 } else { 0.0 },
396 };
397 }
398 Ok(x)
399 }
400 _ => Err(anyhow!("expr {:?} can't be folded to f32", expr)).with_context(|| context!()),
401 }
402}