Skip to main content

luaur_compiler/methods/
compiler_try_compile_unrolled_for.rs

1use crate::enums::type_constant_folding::Type;
2use crate::functions::compute_cost::compute_cost;
3use crate::functions::get_trip_count::get_trip_count;
4use crate::functions::model_cost_cost_model::model_cost_ast_node_ast_local_usize_dense_hash_map_ast_expr_call_i32_dense_hash_map_ast_expr_constant;
5use crate::records::compiler::Compiler;
6use crate::records::constant::Constant;
7use crate::records::variable::Variable;
8use luaur_ast::records::ast_stat_for::AstStatFor;
9
10impl Compiler {
11    pub fn try_compile_unrolled_for(
12        &mut self,
13        stat: *mut AstStatFor,
14        threshold_base: i32,
15        threshold_max_boost: i32,
16    ) -> bool {
17        let stat_ref = unsafe { &*stat };
18
19        let one = Constant {
20            r#type: Type::Type_Number,
21            string_length: 0,
22            data: unsafe { core::mem::zeroed() },
23        };
24        let mut one_data = unsafe { core::mem::zeroed::<crate::records::constant::ConstantData>() };
25        unsafe { one_data.value_number = 1.0 };
26        let one = Constant {
27            r#type: Type::Type_Number,
28            string_length: 0,
29            data: one_data,
30        };
31
32        let fromc = self.get_constant(stat_ref.from);
33        let toc = self.get_constant(stat_ref.to);
34        let stepc = if !stat_ref.step.is_null() {
35            self.get_constant(stat_ref.step)
36        } else {
37            one
38        };
39
40        let trip_count = if fromc.r#type == Type::Type_Number
41            && toc.r#type == Type::Type_Number
42            && stepc.r#type == Type::Type_Number
43        {
44            get_trip_count(
45                unsafe { fromc.data.value_number },
46                unsafe { toc.data.value_number },
47                unsafe { stepc.data.value_number },
48            )
49        } else {
50            -1
51        };
52
53        if trip_count < 0 {
54            unsafe {
55                (*self.bytecode)
56                    .add_debug_remark(format_args!("loop unroll failed: invalid iteration count"));
57            }
58            return false;
59        }
60
61        if trip_count > threshold_base {
62            unsafe {
63                (*self.bytecode).add_debug_remark(format_args!(
64                    "loop unroll failed: too many iterations ({})",
65                    trip_count
66                ));
67            }
68            return false;
69        }
70
71        if let Some(lv) = self.variables.find(&stat_ref.var) {
72            if lv.written {
73                unsafe {
74                    (*self.bytecode).add_debug_remark(format_args!(
75                        "loop unroll failed: mutable loop variable"
76                    ));
77                }
78                return false;
79            }
80        }
81
82        let mut var = stat_ref.var;
83        let cost_model = model_cost_ast_node_ast_local_usize_dense_hash_map_ast_expr_call_i32_dense_hash_map_ast_expr_constant(
84            stat_ref.body as *mut luaur_ast::records::ast_node::AstNode,
85            &var,
86            1,
87            unsafe { &*self.builtins_fold },
88            &self.constants,
89        );
90
91        let varc = true;
92        let unrolled_cost = compute_cost(cost_model, &varc as *const bool, 1) * trip_count;
93        let baseline_cost = (compute_cost(cost_model, core::ptr::null(), 0) + 1) * trip_count;
94        let unroll_profit = if unrolled_cost == 0 {
95            threshold_max_boost
96        } else {
97            threshold_max_boost.min(100 * baseline_cost / unrolled_cost)
98        };
99
100        let threshold = threshold_base * unroll_profit / 100;
101
102        if unrolled_cost > threshold {
103            unsafe {
104                (*self.bytecode).add_debug_remark(format_args!(
105                    "loop unroll failed: too expensive (iterations {}, cost {}, profit {:.2}x)",
106                    trip_count,
107                    unrolled_cost,
108                    unroll_profit as f64 / 100.0
109                ));
110            }
111            return false;
112        }
113
114        unsafe {
115            (*self.bytecode).add_debug_remark(format_args!(
116                "loop unroll succeeded (iterations {}, cost {}, profit {:.2}x)",
117                trip_count,
118                unrolled_cost,
119                unroll_profit as f64 / 100.0
120            ));
121        }
122
123        self.compile_unrolled_for(
124            stat,
125            trip_count,
126            unsafe { fromc.data.value_number },
127            unsafe { stepc.data.value_number },
128        );
129        true
130    }
131}