1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// Unit tests for this pass live in `branch_inversion_test.rs`; compiled only for test builds.
#[cfg(test)]
#[path = "branch_inversion_test.rs"]
mod test;
use cairo_lang_filesystem::ids::SmolStrId;
use cairo_lang_semantic::corelib;
use cairo_lang_utils::Intern;
use salsa::Database;
use crate::ids::FunctionLongId;
use crate::{BlockEnd, Lowered, MatchInfo, Statement, StatementCall};
/// Performs branch inversion optimization on a lowered function.
///
/// The branch inversion optimization looks for an enum match whose input is the output of a call
/// to `bool_not_impl`. For each such match, it replaces the match input with the input of the
/// negation and swaps the two arms, so the match no longer depends on the negated value.
///
/// This optimization is valid only if all paths leading to the match pass through the call to
/// `bool_not_impl`. Therefore, the call to `bool_not_impl` must be in the same block as the match.
/// `reorder_statements` can be used to ensure this condition is met.
///
/// Note: The call to `bool_not_impl` is not deleted, as its output may still be used by other
/// statements (or by the block ending).
///
/// # Panics
///
/// Panics if a match whose input is produced by `bool_not_impl` does not have exactly 2 arms.
pub fn branch_inversion(db: &dyn Database, lowered: &mut Lowered<'_>) {
    // Nothing to rewrite; returning early also skips resolving `bool_not_impl`.
    if lowered.blocks.is_empty() {
        return;
    }
    let bool_not_func_id = FunctionLongId::Semantic(corelib::get_core_function_id(
        db,
        SmolStrId::from(db, "bool_not_impl"),
        vec![],
    ))
    .intern(db);
    for block in lowered.blocks.iter_mut() {
        // Only the statements of the match's own block are searched, which guarantees that every
        // path reaching the match went through the negation.
        if let BlockEnd::Match { info: MatchInfo::Enum(info) } = &mut block.end
            && let Some(negated_condition) = block
                .statements
                .iter()
                .rev()
                .find_map(|stmt| match stmt {
                    Statement::Call(StatementCall {
                        function,
                        inputs,
                        outputs,
                        with_coupon: false,
                        ..
                    }) if function == &bool_not_func_id && outputs[..] == [info.input.var_id] => {
                        // `bool_not_impl` takes a single argument: the condition before negation.
                        Some(inputs[0])
                    }
                    _ => None,
                })
        {
            info.input = negated_condition;
            // Swap the whole arms, then swap their selectors back. The net effect is that each
            // position keeps its selector while everything else in the arm is exchanged. The arm
            // that ran when the negation was `false` now runs when the original condition is
            // `false`'s opposite, and vice versa.
            let [false_arm, true_arm] = &mut info.arms[..] else {
                panic!("Match on bool should have 2 arms.");
            };
            std::mem::swap(false_arm, true_arm);
            std::mem::swap(&mut false_arm.arm_selector, &mut true_arm.arm_selector);
        }
    }
}