darklua_core/rules/
remove_generalized_iteration.rs

1use crate::nodes::{
2    AssignStatement, BinaryExpression, BinaryOperator, Block, DoStatement, Expression,
3    FieldExpression, FunctionCall, Identifier, IfBranch, IfStatement, LocalAssignStatement, Prefix,
4    Statement, StringExpression, TupleArguments, TypedIdentifier, Variable,
5};
6use crate::process::{DefaultVisitor, NodeProcessor, NodeVisitor};
7use crate::rules::{Context, RuleConfiguration, RuleConfigurationError, RuleProperties};
8
9use super::runtime_identifier::RuntimeIdentifierBuilder;
10use super::{Rule, RuleProcessResult};
11
/// Name of the Lua local that the generated code uses to hold the result of
/// `getmetatable(...)`. It is also handed to the runtime identifier builder in
/// `RemoveGeneralizedIteration::process`.
const METATABLE_VARIABLE_NAME: &str = "m";
13
/// Node processor that rewrites single-expression generic `for` loops.
///
/// The three identifier fields are the names of the Lua locals declared in
/// front of each rewritten loop.
struct Processor {
    /// Name of the local that receives the loop's original expression and is
    /// then used as the first expression of the rewritten loop.
    iterator_identifier: String,
    /// Name of the local used as the second expression of the rewritten loop.
    invariant_identifier: String,
    /// Name of the local used as the third expression of the rewritten loop.
    control_identifier: String,
    /// When `true`, the next `process_block` call clears the flag and returns
    /// without touching the block it was given.
    skip_block_once: bool,
}
20
21fn get_type_condition(arg: Expression, type_name: &str) -> Box<BinaryExpression> {
22    let type_call = Box::new(FunctionCall::new(
23        Prefix::from_name("type"),
24        TupleArguments::new(vec![arg]).into(),
25        None,
26    ));
27    Box::new(BinaryExpression::new(
28        BinaryOperator::Equal,
29        Expression::Call(type_call),
30        Expression::String(StringExpression::from_value(type_name)),
31    ))
32}
33
34impl Processor {
35    fn process_into_do(&self, block: &mut Block) -> Option<(usize, Statement)> {
36        let block_stmts = block.mutate_statements();
37        for (i, stmt) in block_stmts.iter_mut().enumerate() {
38            if let Statement::GenericFor(generic_for) = stmt {
39                let exps = generic_for.mutate_expressions();
40                if exps.len() == 1 {
41                    let mut stmts: Vec<Statement> = Vec::new();
42                    let iterator_typed_identifier =
43                        TypedIdentifier::new(self.iterator_identifier.as_str());
44                    let iterator_identifier = iterator_typed_identifier.get_identifier().clone();
45
46                    let invariant_typed_identifier =
47                        TypedIdentifier::new(self.invariant_identifier.as_str());
48                    let invariant_identifier = invariant_typed_identifier.get_identifier().clone();
49
50                    let control_typed_identifier =
51                        TypedIdentifier::new(self.control_identifier.as_str());
52                    let control_identifier = control_typed_identifier.get_identifier().clone();
53
54                    let iterator_local_assign = LocalAssignStatement::new(
55                        vec![iterator_typed_identifier],
56                        vec![exps[0].to_owned()],
57                    );
58                    let invar_control_local_assign = LocalAssignStatement::new(
59                        vec![invariant_typed_identifier, control_typed_identifier],
60                        Vec::new(),
61                    );
62
63                    let iterator_exp = Expression::Identifier(iterator_identifier.clone());
64                    exps[0] = iterator_exp.clone();
65                    let invariant_exp = Expression::Identifier(invariant_identifier.clone());
66                    exps.push(invariant_exp);
67                    let control_exp = Expression::Identifier(control_identifier.clone());
68                    exps.push(control_exp);
69
70                    let if_table_condition = get_type_condition(iterator_exp.clone(), "table");
71
72                    let mt_typed_identifier = TypedIdentifier::new(METATABLE_VARIABLE_NAME);
73                    let mt_identifier = mt_typed_identifier.get_identifier().clone();
74
75                    let get_mt_call = FunctionCall::new(
76                        Prefix::from_name("getmetatable"),
77                        TupleArguments::new(vec![iterator_exp.clone()]).into(),
78                        None,
79                    );
80                    let mt_local_assign = LocalAssignStatement::new(
81                        vec![mt_typed_identifier],
82                        vec![get_mt_call.into()],
83                    );
84
85                    let if_mt_table_condition =
86                        get_type_condition(mt_identifier.clone().into(), "table");
87                    let mt_iter = FieldExpression::new(
88                        Prefix::Identifier(mt_identifier),
89                        Identifier::new("__iter"),
90                    );
91                    let if_mt_iter_function_condition =
92                        get_type_condition(mt_iter.clone().into(), "function");
93
94                    let mt_iter_call = FunctionCall::from_prefix(Prefix::Field(Box::new(mt_iter)));
95                    let assign_from_iter = AssignStatement::new(
96                        vec![
97                            Variable::Identifier(iterator_identifier.clone()),
98                            Variable::Identifier(invariant_identifier.clone()),
99                            Variable::Identifier(control_identifier.clone()),
100                        ],
101                        vec![mt_iter_call.into()],
102                    );
103
104                    let pairs_call = FunctionCall::new(
105                        Prefix::from_name("pairs"),
106                        TupleArguments::new(vec![iterator_identifier.clone().into()]).into(),
107                        None,
108                    );
109                    let assign_from_pairs = AssignStatement::new(
110                        vec![
111                            Variable::Identifier(iterator_identifier),
112                            Variable::Identifier(invariant_identifier),
113                            Variable::Identifier(control_identifier),
114                        ],
115                        vec![pairs_call.into()],
116                    );
117
118                    let if_mt_table_block = Block::new(vec![assign_from_iter.into()], None);
119                    let if_not_mt_table_block = Block::new(vec![assign_from_pairs.into()], None);
120                    let if_mt_table_branch = IfBranch::new(
121                        Expression::Binary(Box::new(BinaryExpression::new(
122                            BinaryOperator::And,
123                            Expression::Binary(if_mt_table_condition),
124                            Expression::Binary(if_mt_iter_function_condition),
125                        ))),
126                        if_mt_table_block,
127                    );
128                    let if_mt_table_stmt =
129                        IfStatement::new(vec![if_mt_table_branch], Some(if_not_mt_table_block));
130
131                    let if_table_block =
132                        Block::new(vec![mt_local_assign.into(), if_mt_table_stmt.into()], None);
133                    let if_table_branch =
134                        IfBranch::new(Expression::Binary(if_table_condition), if_table_block);
135                    let if_table_stmt = IfStatement::new(vec![if_table_branch], None);
136
137                    stmts.push(iterator_local_assign.into());
138                    stmts.push(invar_control_local_assign.into());
139                    stmts.push(if_table_stmt.into());
140                    stmts.push(generic_for.clone().into());
141
142                    block_stmts.remove(i);
143
144                    return Some((i, DoStatement::new(Block::new(stmts, None)).into()));
145                }
146            }
147        }
148        None
149    }
150}
151
152impl NodeProcessor for Processor {
153    fn process_block(&mut self, block: &mut Block) {
154        if self.skip_block_once {
155            self.skip_block_once = false;
156            return;
157        }
158        let do_stmt = self.process_into_do(block);
159        if let Some((i, stmt)) = do_stmt {
160            self.skip_block_once = true;
161            block.insert_statement(i, stmt);
162        }
163    }
164}
165
/// Name of the rule, as returned by `RemoveGeneralizedIteration::get_name`.
pub const REMOVE_GENERALIZED_ITERATION_RULE_NAME: &str = "remove_generalized_iteration";
167
/// A rule that removes generalized iteration.
///
/// Generic `for` loops written with a single expression are wrapped in a `do`
/// block that resolves the iterated value at runtime before the loop runs.
#[derive(Debug, PartialEq, Eq)]
pub struct RemoveGeneralizedIteration {
    /// Format string given to the runtime identifier builder when naming the
    /// locals introduced by the rule. The default value uses the `{name}` and
    /// `{hash}` placeholders.
    runtime_identifier_format: String,
}
173
174impl Default for RemoveGeneralizedIteration {
175    fn default() -> Self {
176        Self {
177            runtime_identifier_format: "_DARKLUA_REMOVE_GENERALIZED_ITERATION_{name}{hash}"
178                .to_string(),
179        }
180    }
181}
182
183impl Rule for RemoveGeneralizedIteration {
184    fn process(&self, block: &mut Block, _: &Context) -> RuleProcessResult {
185        let var_builder = RuntimeIdentifierBuilder::new(
186            self.runtime_identifier_format.as_str(),
187            format!("{block:?}").as_bytes(),
188            Some(vec![METATABLE_VARIABLE_NAME.to_string()]),
189        )?;
190        let mut processor = Processor {
191            iterator_identifier: var_builder.build("iter")?,
192            invariant_identifier: var_builder.build("invar")?,
193            control_identifier: var_builder.build("control")?,
194            skip_block_once: false,
195        };
196        DefaultVisitor::visit_block(block, &mut processor);
197        Ok(())
198    }
199}
200
201impl RuleConfiguration for RemoveGeneralizedIteration {
202    fn configure(&mut self, properties: RuleProperties) -> Result<(), RuleConfigurationError> {
203        for (key, value) in properties {
204            match key.as_str() {
205                "runtime_identifier_format" => {
206                    self.runtime_identifier_format = value.expect_string(&key)?;
207                }
208                _ => return Err(RuleConfigurationError::UnexpectedProperty(key)),
209            }
210        }
211
212        Ok(())
213    }
214
215    fn get_name(&self) -> &'static str {
216        REMOVE_GENERALIZED_ITERATION_RULE_NAME
217    }
218
219    fn serialize_to_properties(&self) -> RuleProperties {
220        RuleProperties::new()
221    }
222}
223
#[cfg(test)]
mod test {
    use super::*;
    use crate::rules::Rule;

    use insta::assert_json_snapshot;

    /// Builds the rule with its default configuration.
    fn new_rule() -> RemoveGeneralizedIteration {
        RemoveGeneralizedIteration::default()
    }

    /// The default rule serializes to the stored JSON snapshot.
    #[test]
    fn serialize_default_rule() {
        let boxed_rule: Box<dyn Rule> = Box::new(new_rule());

        assert_json_snapshot!("default_remove_generalized_iteration", boxed_rule);
    }

    /// Deserializing a configuration with an unknown property fails with a
    /// message naming that property.
    #[test]
    fn configure_with_extra_field_error() {
        let parsed = json5::from_str::<Box<dyn Rule>>(
            r#"{
            rule: 'remove_generalized_iteration',
            runtime_identifier_format: '{name}',
            prop: "something",
        }"#,
        );
        let message = parsed.unwrap_err().to_string();

        pretty_assertions::assert_eq!(message, "unexpected field 'prop'");
    }
}