// cairo_lang_sierra_generator/function_generator.rs

1#[cfg(test)]
2#[path = "function_generator_test.rs"]
3mod test;
4
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_lowering::db::LoweringGroup;
7use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
8use cairo_lang_lowering::{self as lowering, Lowered, LoweringStage};
9use cairo_lang_sierra::extensions::lib_func::SierraApChange;
10use cairo_lang_sierra::ids::ConcreteLibfuncId;
11use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
12use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
13use itertools::{Itertools, zip_eq};
14use salsa::Database;
15
16use crate::block_generator::generate_function_statements;
17use crate::db::SierraGenGroup;
18use crate::expr_generator_context::ExprGeneratorContext;
19use crate::lifetime::{SierraGenVar, find_variable_lifetime};
20use crate::local_variables::{AnalyzeApChangesResult, analyze_ap_changes};
21use crate::pre_sierra;
22use crate::store_variables::{LocalVariables, add_store_statements};
23use crate::utils::{
24    alloc_local_libfunc_id, disable_ap_tracking_libfunc_id, dummy_call_libfunc_id,
25    finalize_locals_libfunc_id, get_concrete_libfunc_id, get_libfunc_signature, return_statement,
26    revoke_ap_tracking_libfunc_id, simple_basic_statement,
27};
28
29#[derive(Clone, Debug, PartialEq, Eq, salsa::Update)]
30pub struct SierraFunctionWithBodyData<'db> {
31    pub function: Maybe<pre_sierra::Function<'db>>,
32    pub ap_change: SierraApChange,
33}
34
35/// Query implementation of [SierraGenGroup::priv_function_with_body_sierra_data].
36#[salsa::tracked(returns(ref))]
37pub fn priv_function_with_body_sierra_data<'db>(
38    db: &'db dyn Database,
39    function_id: ConcreteFunctionWithBodyId<'db>,
40) -> Maybe<SierraFunctionWithBodyData<'db>> {
41    let lowered_function = db.lowered_body(function_id, LoweringStage::Final)?;
42    lowered_function.blocks.has_root()?;
43
44    // Find the local variables.
45    let analyze_ap_changes_result = analyze_ap_changes(db, lowered_function)?;
46
47    let ap_change = match analyze_ap_changes_result.known_ap_change {
48        true => SierraApChange::Known { new_vars_only: false },
49        false => SierraApChange::Unknown,
50    };
51
52    let function = get_function_ap_change_and_code(
53        db,
54        function_id,
55        lowered_function,
56        analyze_ap_changes_result,
57    );
58    Ok(SierraFunctionWithBodyData { ap_change, function })
59}
60
61fn get_function_ap_change_and_code<'db>(
62    db: &'db dyn Database,
63    function_id: ConcreteFunctionWithBodyId<'db>,
64    lowered_function: &Lowered<'db>,
65    analyze_ap_change_result: AnalyzeApChangesResult,
66) -> Maybe<pre_sierra::Function<'db>> {
67    let root_block = lowered_function.blocks.root_block()?;
68    let AnalyzeApChangesResult { known_ap_change, local_variables, ap_tracking_configuration } =
69        analyze_ap_change_result;
70
71    // Get lifetime information.
72    let lifetime = find_variable_lifetime(lowered_function, &local_variables)?;
73
74    let mut context = ExprGeneratorContext::new(
75        db,
76        lowered_function,
77        function_id,
78        &lifetime,
79        ap_tracking_configuration,
80    );
81
82    // If the function starts with `revoke_ap_tracking` then we can avoid
83    // the first `disable_ap_tracking`.
84    if let Some(lowering::Statement::Call(call_stmt)) = root_block.statements.first() {
85        if get_concrete_libfunc_id(db, call_stmt.function, false).1
86            == revoke_ap_tracking_libfunc_id(db)
87        {
88            context.set_ap_tracking(false);
89        }
90    }
91
92    // Generate a label for the function's body.
93    let (label, label_id) = context.new_label();
94
95    // Generate Sierra variables for the function parameters.
96    let parameters = lowered_function
97        .parameters
98        .iter()
99        .map(|param_id| {
100            Ok(cairo_lang_sierra::program::Param {
101                id: context.get_sierra_variable(*param_id),
102                ty: db.get_concrete_type_id(lowered_function.variables[*param_id].ty)?.clone(),
103            })
104        })
105        .collect::<Result<Vec<_>, _>>()?;
106
107    context.push_statement(label);
108
109    let sierra_local_variables = allocate_local_variables(&mut context, &local_variables)?;
110
111    // Revoking ap tracking as the first non-local command for unknown ap-change function, to allow
112    // proper ap-equation solving. TODO(orizi): Fix the solver to not require this constraint.
113    if !known_ap_change && context.get_ap_tracking() {
114        context.push_statement(simple_basic_statement(
115            disable_ap_tracking_libfunc_id(db),
116            &[],
117            &[],
118        ));
119        context.set_ap_tracking(false);
120    }
121
122    // Generate the function's code.
123    let statements = generate_function_statements(context)?;
124
125    let statements = add_store_statements(
126        db,
127        statements,
128        &|id: ConcreteLibfuncId| get_libfunc_signature(db, &id),
129        sierra_local_variables,
130        &parameters,
131    );
132
133    // TODO(spapini): Don't intern objects for the semantic model outside the crate. These should
134    // be regarded as private.
135    Ok(pre_sierra::Function {
136        id: db.intern_sierra_function(function_id.function_id(db)?),
137        body: statements,
138        entry_point: label_id,
139        parameters,
140    })
141}
142
143/// Query implementation of [SierraGenGroup::priv_get_dummy_function].
144#[salsa::tracked(returns(ref))]
145pub fn priv_get_dummy_function<'db>(
146    db: &'db dyn Database,
147    function_id: ConcreteFunctionWithBodyId<'db>,
148) -> Maybe<pre_sierra::Function<'db>> {
149    // TODO(ilya): Remove the following query.
150    let lowered_function = db.lowered_body(function_id, LoweringStage::PreOptimizations)?;
151    let ap_tracking_configuration = Default::default();
152    let lifetime = Default::default();
153
154    let mut context = ExprGeneratorContext::new(
155        db,
156        lowered_function,
157        function_id,
158        &lifetime,
159        ap_tracking_configuration,
160    );
161
162    // Generate a label for the function's body.
163    let (label, label_id) = context.new_label();
164    context.push_statement(label);
165
166    let sierra_id = db.intern_sierra_function(function_id.function_id(db)?);
167    let sierra_signature = db.get_function_signature(sierra_id.clone()).unwrap();
168
169    let param_vars = (0..sierra_signature.param_types.len() as u64)
170        .map(cairo_lang_sierra::ids::VarId::new)
171        .collect_vec();
172
173    let ret_vars = (0..sierra_signature.ret_types.len() as u64)
174        .map(cairo_lang_sierra::ids::VarId::new)
175        .collect_vec();
176
177    // Generate Sierra variables for the function parameters.
178    let parameters = zip_eq(&param_vars, &sierra_signature.param_types)
179        .map(|(id, ty)| Ok(cairo_lang_sierra::program::Param { id: id.clone(), ty: ty.clone() }))
180        .collect::<Result<Vec<_>, _>>()?;
181
182    context.push_statement(simple_basic_statement(
183        dummy_call_libfunc_id(db, sierra_id, sierra_signature),
184        &param_vars[..],
185        &ret_vars[..],
186    ));
187
188    context.push_statement(return_statement(ret_vars));
189
190    Ok(pre_sierra::Function {
191        id: db.intern_sierra_function(function_id.function_id(db)?),
192        body: context.statements(),
193        entry_point: label_id,
194        parameters,
195    })
196}
197
198/// Allocates space for the local variables.
199/// Returns:
200/// * A map from a Sierra variable that should be stored as local variable to its allocated space
201///   (uninitialized local variable).
202/// * A list of Sierra statements.
203fn allocate_local_variables<'db>(
204    context: &mut ExprGeneratorContext<'db, '_>,
205    local_variables: &OrderedHashSet<lowering::VariableId>,
206) -> Maybe<LocalVariables> {
207    let mut sierra_local_variables =
208        OrderedHashMap::<cairo_lang_sierra::ids::VarId, cairo_lang_sierra::ids::VarId>::default();
209    for lowering_var_id in local_variables {
210        let sierra_var_id = context.get_sierra_variable(*lowering_var_id);
211        let uninitialized_local_var_id =
212            context.get_sierra_variable(SierraGenVar::UninitializedLocal(*lowering_var_id));
213        context.push_statement(simple_basic_statement(
214            alloc_local_libfunc_id(
215                context.get_db(),
216                context.get_variable_sierra_type(*lowering_var_id)?,
217            ),
218            &[],
219            std::slice::from_ref(&uninitialized_local_var_id),
220        ));
221
222        sierra_local_variables.insert(sierra_var_id, uninitialized_local_var_id);
223    }
224
225    // Add finalize_locals() statement.
226    if !local_variables.is_empty() {
227        context.push_statement(simple_basic_statement(
228            finalize_locals_libfunc_id(context.get_db()),
229            &[],
230            &[],
231        ));
232    }
233
234    Ok(sierra_local_variables)
235}