1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
use cairo_lang_debug::DebugWithDb;
use cairo_lang_diagnostics::Maybe;
use cairo_lang_semantic as semantic;
use cairo_lang_semantic::corelib;
use cairo_lang_utils::extract_matches;
use num_traits::Zero;
use semantic::ExprFunctionCallArg;

use super::block_builder::{BlockBuilder, SealedBlockBuilder};
use super::context::{LoweredExpr, LoweringContext, LoweringFlowError, LoweringResult};
use super::lowered_expr_to_block_scope_end;
use crate::ids::{LocationId, SemanticFunctionIdEx};
use crate::lower::context::VarRequest;
use crate::lower::{
    create_subscope_with_bound_refs, generators, lower_block, lower_expr_to_var_usage,
};
use crate::{MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo};

/// Shape of an `if` condition, used by [lower_expr_if] to choose a lowering strategy.
// NOTE(review): `dead_code` is presumably allowed because the `ExprId` payload of `BoolExpr` is
// never read (the bool lowering reads `expr.condition` directly) — confirm before removing.
#[allow(dead_code)]
enum IfCondition {
    /// Any other boolean condition; lowered by matching on the core `bool` enum.
    BoolExpr(semantic::ExprId),
    /// An equality between two `felt252` expressions; lowered via `felt252_is_zero`.
    Eq(semantic::ExprId, semantic::ExprId),
}

/// Classifies the condition of an `if` expression as an [IfCondition], so that the lowering can
/// pick a specialized strategy.
///
/// The condition is reported as [IfCondition::Eq] only when it is a direct call to the core
/// `felt252_eq` function whose two arguments are both passed by value and are both snapshot
/// expressions; in that case the snapshotted inner expressions are returned. Every other
/// condition is reported as [IfCondition::BoolExpr].
// TODO(lior): Make it an actual tree (handling && and ||).
fn analyze_condition(ctx: &LoweringContext<'_, '_>, expr_id: semantic::ExprId) -> IfCondition {
    match felt252_eq_snapshot_operands(ctx, expr_id) {
        Some((lhs, rhs)) => IfCondition::Eq(lhs, rhs),
        None => IfCondition::BoolExpr(expr_id),
    }
}

/// Returns the inner expressions of the two snapshot arguments of a `felt252_eq` call, or `None`
/// when `expr_id` does not have exactly that shape.
fn felt252_eq_snapshot_operands(
    ctx: &LoweringContext<'_, '_>,
    expr_id: semantic::ExprId,
) -> Option<(semantic::ExprId, semantic::ExprId)> {
    let exprs = &ctx.function_body.exprs;
    let semantic::Expr::FunctionCall(call) = &exprs[expr_id] else {
        return None;
    };
    if call.function != corelib::felt252_eq(ctx.db.upcast()) {
        return None;
    }
    let [ExprFunctionCallArg::Value(lhs_id), ExprFunctionCallArg::Value(rhs_id)] = &call.args[..]
    else {
        return None;
    };
    let (semantic::Expr::Snapshot(lhs), semantic::Expr::Snapshot(rhs)) =
        (&exprs[*lhs_id], &exprs[*rhs_id])
    else {
        return None;
    };
    Some((lhs.inner, rhs.inner))
}

/// Returns `true` iff `expr_id` refers to a literal expression whose value is zero.
fn is_zero(ctx: &LoweringContext<'_, '_>, expr_id: semantic::ExprId) -> bool {
    if let semantic::Expr::Literal(literal) = &ctx.function_body.exprs[expr_id] {
        literal.value.is_zero()
    } else {
        false
    }
}

/// Lowers an expression of type [semantic::ExprIf].
///
/// The condition is first classified by [analyze_condition]. `felt252` equalities are lowered by
/// [lower_expr_if_eq]; every other condition is lowered by [lower_expr_if_bool].
pub fn lower_expr_if(
    ctx: &mut LoweringContext<'_, '_>,
    builder: &mut BlockBuilder,
    expr: &semantic::ExprIf,
) -> LoweringResult<LoweredExpr> {
    let condition_kind = analyze_condition(ctx, expr.condition);
    match condition_kind {
        IfCondition::Eq(lhs, rhs) => lower_expr_if_eq(ctx, builder, expr, lhs, rhs),
        IfCondition::BoolExpr(_) => lower_expr_if_bool(ctx, builder, expr),
    }
}

/// Lowers an expression of type [semantic::ExprIf], for the case of [IfCondition::BoolExpr].
///
/// The condition is lowered to a variable in `builder`. The `if` block and the (possibly
/// implicit) else block are then each lowered into their own subscope of `builder`. The result
/// is produced by `builder.merge_and_end_with_match`, using a match on the core `bool` enum: the
/// `false` arm goes to the else block and the `true` arm goes to the main block.
///
/// Errors from lowering the condition are propagated as-is. Failures while lowering either branch
/// are wrapped in [LoweringFlowError::Failed].
pub fn lower_expr_if_bool(
    ctx: &mut LoweringContext<'_, '_>,
    builder: &mut BlockBuilder,
    expr: &semantic::ExprIf,
) -> LoweringResult<LoweredExpr> {
    log::trace!("Lowering a boolean if expression: {:?}", expr.debug(&ctx.expr_formatter));
    // The condition cannot be unit.
    let condition = lower_expr_to_var_usage(ctx, builder, expr.condition)?;
    let semantic_db = ctx.db.upcast();
    let unit_ty = corelib::unit_ty(semantic_db);
    let if_location = ctx.get_location(expr.stable_ptr.untyped());

    // NOTE(review): subscopes and variables below are allocated through `ctx` in a fixed order;
    // reordering these statements likely changes the resulting block/variable ids — confirm
    // before restructuring.

    // Main block.
    let subscope_main = create_subscope_with_bound_refs(ctx, builder);
    let block_main_id = subscope_main.block_id;
    // Cloned so the block no longer borrows `ctx`, which `lower_block` needs mutably.
    let main_block =
        extract_matches!(&ctx.function_body.exprs[expr.if_block], semantic::Expr::Block).clone();
    // Unit-typed variable bound by the `true` arm of the match below.
    let main_block_var_id = ctx.new_var(VarRequest {
        ty: unit_ty,
        location: ctx.get_location(main_block.stable_ptr.untyped()),
    });
    let block_main =
        lower_block(ctx, subscope_main, &main_block).map_err(LoweringFlowError::Failed)?;

    // Else block.
    let subscope_else = create_subscope_with_bound_refs(ctx, builder);
    let block_else_id = subscope_else.block_id;

    // Unit-typed variable bound by the `false` arm of the match below.
    let else_block_input_var_id = ctx.new_var(VarRequest { ty: unit_ty, location: if_location });
    let block_else = lower_optional_else_block(ctx, subscope_else, expr.else_block, if_location)
        .map_err(LoweringFlowError::Failed)?;

    // Dispatch on the lowered condition: `false` -> else block, `true` -> main block.
    let match_info = MatchInfo::Enum(MatchEnumInfo {
        concrete_enum_id: corelib::core_bool_enum(semantic_db),
        input: condition,
        arms: vec![
            MatchArm {
                variant_id: corelib::false_variant(semantic_db),
                block_id: block_else_id,
                var_ids: vec![else_block_input_var_id],
            },
            MatchArm {
                variant_id: corelib::true_variant(semantic_db),
                block_id: block_main_id,
                var_ids: vec![main_block_var_id],
            },
        ],
        location: if_location,
    });
    builder.merge_and_end_with_match(ctx, match_info, vec![block_main, block_else], if_location)
}

/// Lowers an expression of type [semantic::ExprIf], for the case of [IfCondition::Eq].
///
/// `expr_a` and `expr_b` are the two operands of the `felt252` equality in the condition. The
/// equality is not evaluated to a `bool`. Instead, a single `felt252` value is lowered and matched
/// with the core `felt252_is_zero` extern:
/// - if `expr_b` is a zero literal, the value is `expr_a` (`expr_b` is not lowered);
/// - otherwise, if `expr_a` is a zero literal, the value is `expr_b` (`expr_a` is not lowered);
/// - otherwise, both operands are lowered and the value is `expr_a - expr_b` (`felt252_sub`).
///
/// The zero arm of the match goes to the `if` block. The non-zero arm binds a variable of type
/// `core_nonzero_ty(felt252)` and goes to the (possibly implicit) else block.
///
/// Errors from lowering the operands are propagated as-is. Failures while lowering either branch
/// are wrapped in [LoweringFlowError::Failed].
pub fn lower_expr_if_eq(
    ctx: &mut LoweringContext<'_, '_>,
    builder: &mut BlockBuilder,
    expr: &semantic::ExprIf,
    expr_a: semantic::ExprId,
    expr_b: semantic::ExprId,
) -> LoweringResult<LoweredExpr> {
    log::trace!(
        "Started lowering of an if-eq-zero expression: {:?}",
        expr.debug(&ctx.expr_formatter)
    );
    let if_location = ctx.get_location(expr.stable_ptr.untyped());
    // The `felt252` value whose zero-ness decides which branch is taken.
    let match_input = if is_zero(ctx, expr_b) {
        lower_expr_to_var_usage(ctx, builder, expr_a)?
    } else if is_zero(ctx, expr_a) {
        lower_expr_to_var_usage(ctx, builder, expr_b)?
    } else {
        // General case: `a == b` is tested as `a - b == 0`.
        let lowered_a = lower_expr_to_var_usage(ctx, builder, expr_a)?;
        let lowered_b = lower_expr_to_var_usage(ctx, builder, expr_b)?;
        let ret_ty = corelib::core_felt252_ty(ctx.db.upcast());
        let call_result = generators::Call {
            function: corelib::felt252_sub(ctx.db.upcast()).lowered(ctx.db),
            inputs: vec![lowered_a, lowered_b],
            extra_ret_tys: vec![],
            ret_tys: vec![ret_ty],
            location: ctx
                .get_location(ctx.function_body.exprs[expr.condition].stable_ptr().untyped()),
        }
        .add(ctx, &mut builder.statements);
        // `ret_tys` has a single entry, so `returns` is expected to hold exactly one variable.
        call_result.returns.into_iter().next().unwrap()
    };

    let semantic_db = ctx.db.upcast();

    // NOTE(review): subscopes and variables below are allocated through `ctx` in a fixed order;
    // reordering these statements likely changes the resulting block/variable ids — confirm
    // before restructuring.

    // Main block.
    let subscope_main = create_subscope_with_bound_refs(ctx, builder);
    let block_main_id = subscope_main.block_id;
    // Cloned so the block no longer borrows `ctx`, which `lower_block` needs mutably.
    let body_expr = ctx.function_body.exprs[expr.if_block].clone();
    let block_main =
        lower_block(ctx, subscope_main, extract_matches!(&body_expr, semantic::Expr::Block))
            .map_err(LoweringFlowError::Failed)?;

    // Else block.
    let non_zero_type =
        corelib::core_nonzero_ty(semantic_db, corelib::core_felt252_ty(semantic_db));
    let subscope_else = create_subscope_with_bound_refs(ctx, builder);
    let block_else_id = subscope_else.block_id;

    // Variable bound by the non-zero arm of the match below.
    let else_block_input_var_id =
        ctx.new_var(VarRequest { ty: non_zero_type, location: if_location });
    let block_else = lower_optional_else_block(ctx, subscope_else, expr.else_block, if_location)
        .map_err(LoweringFlowError::Failed)?;

    // Dispatch on `felt252_is_zero`: zero -> main block, non-zero -> else block.
    let match_info = MatchInfo::Extern(MatchExternInfo {
        function: corelib::core_felt252_is_zero(semantic_db).lowered(ctx.db),
        inputs: vec![match_input],
        arms: vec![
            MatchArm {
                variant_id: corelib::jump_nz_zero_variant(semantic_db),
                block_id: block_main_id,
                // The zero arm binds no variables.
                var_ids: vec![],
            },
            MatchArm {
                variant_id: corelib::jump_nz_nonzero_variant(semantic_db),
                block_id: block_else_id,
                var_ids: vec![else_block_input_var_id],
            },
        ],
        location: if_location,
    });
    builder.merge_and_end_with_match(ctx, match_info, vec![block_main, block_else], if_location)
}

/// Lowers the else branch of an `if` expression into `builder`.
///
/// * `None` (no `else` written): the scope is closed with an empty tuple, i.e. a unit value,
///   located at `if_location`.
/// * `Some` pointing at a block expression: that block is lowered.
/// * `Some` pointing at another `if` expression (an `else if` chain): it is lowered through
///   [lower_expr_if], and the scope is closed with its result.
///
/// Returns the sealed block builder of the else branch.
fn lower_optional_else_block(
    ctx: &mut LoweringContext<'_, '_>,
    mut builder: BlockBuilder,
    else_expr_opt: Option<semantic::ExprId>,
    if_location: LocationId,
) -> Maybe<SealedBlockBuilder> {
    log::trace!("Started lowering of an optional else block.");
    let Some(else_id) = else_expr_opt else {
        // No explicit `else`: the branch evaluates to `()`.
        let unit = LoweredExpr::Tuple { exprs: vec![], location: if_location };
        return lowered_expr_to_block_scope_end(ctx, builder, Ok(unit));
    };
    // Cloned so the expression no longer borrows `ctx`, which the lowering below needs mutably.
    let else_expr = ctx.function_body.exprs[else_id].clone();
    match &else_expr {
        semantic::Expr::If(nested_if) => {
            let lowered = lower_expr_if(ctx, &mut builder, nested_if);
            lowered_expr_to_block_scope_end(ctx, builder, lowered)
        }
        semantic::Expr::Block(block) => lower_block(ctx, builder, block),
        // NOTE(review): presumably the semantic model only allows a block or an `if` after
        // `else`; confirm.
        _ => unreachable!(),
    }
}