cairo_lang_sierra_generator/
function_generator.rs1#[cfg(test)]
2#[path = "function_generator_test.rs"]
3mod test;
4
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_lowering::db::LoweringGroup;
7use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
8use cairo_lang_lowering::{self as lowering, Lowered, LoweringStage};
9use cairo_lang_sierra::extensions::lib_func::SierraApChange;
10use cairo_lang_sierra::ids::ConcreteLibfuncId;
11use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
12use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
13use itertools::{Itertools, zip_eq};
14use salsa::Database;
15
16use crate::block_generator::generate_function_statements;
17use crate::db::SierraGenGroup;
18use crate::expr_generator_context::ExprGeneratorContext;
19use crate::lifetime::{SierraGenVar, find_variable_lifetime};
20use crate::local_variables::{AnalyzeApChangesResult, analyze_ap_changes};
21use crate::pre_sierra;
22use crate::store_variables::{LocalVariables, add_store_statements};
23use crate::utils::{
24 alloc_local_libfunc_id, disable_ap_tracking_libfunc_id, dummy_call_libfunc_id,
25 finalize_locals_libfunc_id, get_concrete_libfunc_id, get_libfunc_signature, return_statement,
26 revoke_ap_tracking_libfunc_id, simple_basic_statement,
27};
28
29#[derive(Clone, Debug, PartialEq, Eq, salsa::Update)]
30pub struct SierraFunctionWithBodyData<'db> {
31 pub function: Maybe<pre_sierra::Function<'db>>,
32 pub ap_change: SierraApChange,
33}
34
35#[salsa::tracked(returns(ref))]
37pub fn priv_function_with_body_sierra_data<'db>(
38 db: &'db dyn Database,
39 function_id: ConcreteFunctionWithBodyId<'db>,
40) -> Maybe<SierraFunctionWithBodyData<'db>> {
41 let lowered_function = db.lowered_body(function_id, LoweringStage::Final)?;
42 lowered_function.blocks.has_root()?;
43
44 let analyze_ap_changes_result = analyze_ap_changes(db, lowered_function)?;
46
47 let ap_change = match analyze_ap_changes_result.known_ap_change {
48 true => SierraApChange::Known { new_vars_only: false },
49 false => SierraApChange::Unknown,
50 };
51
52 let function = get_function_ap_change_and_code(
53 db,
54 function_id,
55 lowered_function,
56 analyze_ap_changes_result,
57 );
58 Ok(SierraFunctionWithBodyData { ap_change, function })
59}
60
61fn get_function_ap_change_and_code<'db>(
62 db: &'db dyn Database,
63 function_id: ConcreteFunctionWithBodyId<'db>,
64 lowered_function: &Lowered<'db>,
65 analyze_ap_change_result: AnalyzeApChangesResult,
66) -> Maybe<pre_sierra::Function<'db>> {
67 let root_block = lowered_function.blocks.root_block()?;
68 let AnalyzeApChangesResult { known_ap_change, local_variables, ap_tracking_configuration } =
69 analyze_ap_change_result;
70
71 let lifetime = find_variable_lifetime(lowered_function, &local_variables)?;
73
74 let mut context = ExprGeneratorContext::new(
75 db,
76 lowered_function,
77 function_id,
78 &lifetime,
79 ap_tracking_configuration,
80 );
81
82 if let Some(lowering::Statement::Call(call_stmt)) = root_block.statements.first() {
85 if get_concrete_libfunc_id(db, call_stmt.function, false).1
86 == revoke_ap_tracking_libfunc_id(db)
87 {
88 context.set_ap_tracking(false);
89 }
90 }
91
92 let (label, label_id) = context.new_label();
94
95 let parameters = lowered_function
97 .parameters
98 .iter()
99 .map(|param_id| {
100 Ok(cairo_lang_sierra::program::Param {
101 id: context.get_sierra_variable(*param_id),
102 ty: db.get_concrete_type_id(lowered_function.variables[*param_id].ty)?.clone(),
103 })
104 })
105 .collect::<Result<Vec<_>, _>>()?;
106
107 context.push_statement(label);
108
109 let sierra_local_variables = allocate_local_variables(&mut context, &local_variables)?;
110
111 if !known_ap_change && context.get_ap_tracking() {
114 context.push_statement(simple_basic_statement(
115 disable_ap_tracking_libfunc_id(db),
116 &[],
117 &[],
118 ));
119 context.set_ap_tracking(false);
120 }
121
122 let statements = generate_function_statements(context)?;
124
125 let statements = add_store_statements(
126 db,
127 statements,
128 &|id: ConcreteLibfuncId| get_libfunc_signature(db, &id),
129 sierra_local_variables,
130 ¶meters,
131 );
132
133 Ok(pre_sierra::Function {
136 id: db.intern_sierra_function(function_id.function_id(db)?),
137 body: statements,
138 entry_point: label_id,
139 parameters,
140 })
141}
142
143#[salsa::tracked(returns(ref))]
145pub fn priv_get_dummy_function<'db>(
146 db: &'db dyn Database,
147 function_id: ConcreteFunctionWithBodyId<'db>,
148) -> Maybe<pre_sierra::Function<'db>> {
149 let lowered_function = db.lowered_body(function_id, LoweringStage::PreOptimizations)?;
151 let ap_tracking_configuration = Default::default();
152 let lifetime = Default::default();
153
154 let mut context = ExprGeneratorContext::new(
155 db,
156 lowered_function,
157 function_id,
158 &lifetime,
159 ap_tracking_configuration,
160 );
161
162 let (label, label_id) = context.new_label();
164 context.push_statement(label);
165
166 let sierra_id = db.intern_sierra_function(function_id.function_id(db)?);
167 let sierra_signature = db.get_function_signature(sierra_id.clone()).unwrap();
168
169 let param_vars = (0..sierra_signature.param_types.len() as u64)
170 .map(cairo_lang_sierra::ids::VarId::new)
171 .collect_vec();
172
173 let ret_vars = (0..sierra_signature.ret_types.len() as u64)
174 .map(cairo_lang_sierra::ids::VarId::new)
175 .collect_vec();
176
177 let parameters = zip_eq(¶m_vars, &sierra_signature.param_types)
179 .map(|(id, ty)| Ok(cairo_lang_sierra::program::Param { id: id.clone(), ty: ty.clone() }))
180 .collect::<Result<Vec<_>, _>>()?;
181
182 context.push_statement(simple_basic_statement(
183 dummy_call_libfunc_id(db, sierra_id, sierra_signature),
184 ¶m_vars[..],
185 &ret_vars[..],
186 ));
187
188 context.push_statement(return_statement(ret_vars));
189
190 Ok(pre_sierra::Function {
191 id: db.intern_sierra_function(function_id.function_id(db)?),
192 body: context.statements(),
193 entry_point: label_id,
194 parameters,
195 })
196}
197
198fn allocate_local_variables<'db>(
204 context: &mut ExprGeneratorContext<'db, '_>,
205 local_variables: &OrderedHashSet<lowering::VariableId>,
206) -> Maybe<LocalVariables> {
207 let mut sierra_local_variables =
208 OrderedHashMap::<cairo_lang_sierra::ids::VarId, cairo_lang_sierra::ids::VarId>::default();
209 for lowering_var_id in local_variables {
210 let sierra_var_id = context.get_sierra_variable(*lowering_var_id);
211 let uninitialized_local_var_id =
212 context.get_sierra_variable(SierraGenVar::UninitializedLocal(*lowering_var_id));
213 context.push_statement(simple_basic_statement(
214 alloc_local_libfunc_id(
215 context.get_db(),
216 context.get_variable_sierra_type(*lowering_var_id)?,
217 ),
218 &[],
219 std::slice::from_ref(&uninitialized_local_var_id),
220 ));
221
222 sierra_local_variables.insert(sierra_var_id, uninitialized_local_var_id);
223 }
224
225 if !local_variables.is_empty() {
227 context.push_statement(simple_basic_statement(
228 finalize_locals_libfunc_id(context.get_db()),
229 &[],
230 &[],
231 ));
232 }
233
234 Ok(sierra_local_variables)
235}