cairo_lang_sierra_generator/
program_generator.rs1use std::collections::VecDeque;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs::db::DefsGroup;
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_filesystem::ids::{CrateId, Tracked};
7use cairo_lang_filesystem::location_marks::get_location_marks;
8use cairo_lang_lowering::ids::{ConcreteFunctionWithBodyId, LocationId};
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId, VarId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
14use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
15use cairo_lang_utils::try_extract_matches;
16use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
17use itertools::{Itertools, chain};
18use salsa::Database;
19
20use crate::db::{SierraGenGroup, sierra_concrete_long_id};
21use crate::extra_sierra_info::type_has_const_size;
22use crate::pre_sierra;
23use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
24use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
25use crate::specialization_context::SierraSignatureSpecializationContext;
26use crate::statements_locations::StatementsLocations;
27
28#[cfg(test)]
29#[path = "program_generator_test.rs"]
30mod test;
31
32fn collect_and_generate_libfunc_declarations<'db>(
35 db: &dyn Database,
36 statements: &[pre_sierra::StatementWithLocation<'db>],
37) -> Vec<program::LibfuncDeclaration> {
38 let mut declared_libfuncs = UnorderedHashSet::<ConcreteLibfuncId>::default();
39 statements
40 .iter()
41 .filter_map(|statement| match &statement.statement {
42 pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
43 declared_libfuncs.insert(invocation.libfunc_id.clone()).then(|| {
44 program::LibfuncDeclaration {
45 id: invocation.libfunc_id.clone(),
46 long_id: db.lookup_concrete_lib_func(&invocation.libfunc_id),
47 }
48 })
49 }
50 pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
51 | pre_sierra::Statement::Label(_) => None,
52 pre_sierra::Statement::PushValues(_) => {
53 panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
54 }
55 })
56 .collect()
57}
58
59fn generate_type_declarations(
62 db: &dyn Database,
63 libfunc_declarations: &[program::LibfuncDeclaration],
64 functions: &[program::Function],
65) -> Vec<program::TypeDeclaration> {
66 let mut declarations = vec![];
67 let mut already_declared = UnorderedHashSet::default();
68 let mut remaining_types = collect_used_types(db, libfunc_declarations, functions);
69 while let Some(ty) = remaining_types.iter().next().cloned() {
70 remaining_types.swap_remove(&ty);
71 generate_type_declarations_helper(
72 db,
73 &ty,
74 &mut declarations,
75 &mut remaining_types,
76 &mut already_declared,
77 );
78 }
79 declarations
80}
81
82fn generate_type_declarations_helper(
87 db: &dyn Database,
88 ty: &ConcreteTypeId,
89 declarations: &mut Vec<program::TypeDeclaration>,
90 remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
91 already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
92) {
93 if already_declared.contains(ty) {
94 return;
95 }
96 let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
97 already_declared.insert(ty.clone());
98 let inner_tys = long_id
99 .generic_args
100 .iter()
101 .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
102 if type_has_const_size(&long_id.generic_id) {
105 remaining_types.extend(inner_tys.cloned());
106 } else {
107 for inner_ty in inner_tys {
108 generate_type_declarations_helper(
109 db,
110 inner_ty,
111 declarations,
112 remaining_types,
113 already_declared,
114 );
115 }
116 }
117
118 let type_info = db.get_type_info(ty.clone()).unwrap();
119 declarations.push(program::TypeDeclaration {
120 id: ty.clone(),
121 long_id: long_id.as_ref().clone(),
122 declared_type_info: Some(DeclaredTypeInfo {
123 storable: type_info.storable,
124 droppable: type_info.droppable,
125 duplicatable: type_info.duplicatable,
126 zero_sized: type_info.zero_sized,
127 }),
128 });
129}
130
131fn collect_used_types(
134 db: &dyn Database,
135 libfunc_declarations: &[program::LibfuncDeclaration],
136 functions: &[program::Function],
137) -> OrderedHashSet<ConcreteTypeId> {
138 let mut all_types = OrderedHashSet::default();
139 for libfunc in libfunc_declarations {
141 let types = db.priv_libfunc_dependencies(libfunc.id.clone());
142 all_types.extend(types.iter().cloned());
143 }
144
145 all_types.extend(
153 functions.iter().flat_map(|func| {
154 chain!(&func.signature.param_types, &func.signature.ret_types).cloned()
155 }),
156 );
157 all_types
158}
159
160#[salsa::tracked(returns(ref))]
162pub fn priv_libfunc_dependencies(
163 db: &dyn Database,
164 _tracked: Tracked,
165 libfunc_id: ConcreteLibfuncId,
166) -> Vec<ConcreteTypeId> {
167 let long_id = db.lookup_concrete_lib_func(&libfunc_id);
168 let signature = CoreLibfunc::specialize_signature_by_id(
169 &SierraSignatureSpecializationContext(db),
170 &long_id.generic_id,
171 &long_id.generic_args,
172 )
173 .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
176 DebugReplacer { db }.replace_libfunc_id(&libfunc_id)));
177 let mut all_types = vec![];
179 let mut add_ty = |ty: ConcreteTypeId| {
180 if !all_types.contains(&ty) {
181 all_types.push(ty);
182 }
183 };
184 for param_signature in signature.param_signatures {
185 add_ty(param_signature.ty);
186 }
187 for info in signature.branch_signatures {
188 for var in info.vars {
189 add_ty(var.ty);
190 }
191 }
192 for arg in long_id.generic_args {
193 if let program::GenericArg::Type(ty) = arg {
194 add_ty(ty);
195 }
196 }
197 all_types
198}
199
200#[derive(Clone, Debug, Eq, PartialEq)]
201pub struct SierraProgramWithDebug<'db> {
202 pub program: cairo_lang_sierra::program::Program,
203 pub debug_info: SierraProgramDebugInfo<'db>,
204}
205
206unsafe impl<'db> salsa::Update for SierraProgramWithDebug<'db> {
207 unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
208 let old_value = unsafe { &mut *old_pointer };
209 if old_value == &new_value {
210 return false;
211 }
212 *old_value = new_value;
213 true
214 }
215}
216impl<'db> DebugWithDb<'db> for SierraProgramWithDebug<'db> {
219 type Db = dyn Database;
220
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Database) -> std::fmt::Result {
222 let sierra_program = DebugReplacer { db }.apply(&self.program);
223 for declaration in &sierra_program.type_declarations {
224 writeln!(f, "{declaration};")?;
225 }
226 writeln!(f)?;
227 for declaration in &sierra_program.libfunc_declarations {
228 writeln!(f, "{declaration};")?;
229 }
230 writeln!(f)?;
231 let mut funcs = sierra_program.funcs.iter().peekable();
232 while let Some(func) = funcs.next() {
233 let start = func.entry_point.0;
234 let end = funcs
235 .peek()
236 .map(|f| f.entry_point.0)
237 .unwrap_or_else(|| sierra_program.statements.len());
238 writeln!(f, "// {}:", func.id)?;
239 for param in &func.params {
240 writeln!(f, "// {param}")?;
241 }
242 for i in start..end {
243 writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
244 if let Some(loc) =
245 &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
246 {
247 let loc = get_location_marks(db, &loc.first().unwrap().span_in_file(db), true);
248 println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
249 }
250 }
251 }
252 writeln!(f)?;
253 for func in &sierra_program.funcs {
254 writeln!(f, "{func};")?;
255 }
256 Ok(())
257 }
258}
259
260#[derive(Clone, Debug, Eq, PartialEq)]
261pub struct SierraProgramDebugInfo<'db> {
262 pub statements_locations: StatementsLocations<'db>,
263 pub variable_location: OrderedHashMap<FunctionId, OrderedHashMap<VarId, LocationId<'db>>>,
264}
265
266#[salsa::tracked(returns(ref))]
267pub fn get_sierra_program_for_functions<'db>(
268 db: &'db dyn Database,
269 _tracked: Tracked,
270 requested_function_ids: Vec<ConcreteFunctionWithBodyId<'db>>,
271) -> Maybe<SierraProgramWithDebug<'db>> {
272 let mut functions: Vec<&'db pre_sierra::Function<'_>> = vec![];
273 let mut statements: Vec<pre_sierra::StatementWithLocation<'_>> = vec![];
274 let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::default();
275 let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId<'_>> =
276 requested_function_ids.into_iter().collect();
277 while let Some(function_id) = function_id_queue.pop_front() {
278 if !processed_function_ids.insert(function_id) {
279 continue;
280 }
281 let function = db.function_with_body_sierra(function_id)?;
282 functions.push(function);
283 statements.extend_from_slice(&function.body);
284
285 for statement in &function.body {
286 if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
287 function_id_queue.push_back(related_function_id);
288 }
289 }
290 }
291
292 let AssembledProgram { program, statements_locations, variable_location } =
293 assemble_program(db, functions, statements);
294 Ok(SierraProgramWithDebug {
295 program,
296 debug_info: SierraProgramDebugInfo {
297 statements_locations: StatementsLocations::from_locations_vec(db, statements_locations),
298 variable_location,
299 },
300 })
301}
302
303struct AssembledProgram<'db> {
305 program: program::Program,
307 statements_locations: Vec<Option<LocationId<'db>>>,
309 variable_location: OrderedHashMap<FunctionId, OrderedHashMap<VarId, LocationId<'db>>>,
311}
312
313fn assemble_program<'db>(
316 db: &dyn Database,
317 functions: Vec<&'db pre_sierra::Function<'db>>,
318 statements: Vec<pre_sierra::StatementWithLocation<'db>>,
319) -> AssembledProgram<'db> {
320 let label_replacer = LabelReplacer::from_statements(&statements);
321 let variable_location = functions
322 .iter()
323 .map(|f| (f.id.clone(), f.variable_locations.iter().cloned().collect()))
324 .collect();
325 let funcs = functions
326 .into_iter()
327 .map(|function| {
328 let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
329 program::Function::new(
330 function.id.clone(),
331 function.parameters.clone(),
332 sierra_signature.ret_types.clone(),
333 label_replacer.handle_label_id(function.entry_point),
334 )
335 })
336 .collect_vec();
337
338 let libfunc_declarations = collect_and_generate_libfunc_declarations(db, &statements);
339 let type_declarations = generate_type_declarations(db, &libfunc_declarations, &funcs);
340 let (resolved_statements, statements_locations) =
342 resolve_labels_and_extract_locations(statements, &label_replacer);
343 let program = program::Program {
344 type_declarations,
345 libfunc_declarations,
346 statements: resolved_statements,
347 funcs,
348 };
349 AssembledProgram { program, statements_locations, variable_location }
350}
351
352pub fn try_get_function_with_body_id<'db>(
354 db: &'db dyn Database,
355 statement: &pre_sierra::StatementWithLocation<'db>,
356) -> Option<ConcreteFunctionWithBodyId<'db>> {
357 let invc = try_extract_matches!(
358 try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
359 program::GenStatement::Invocation
360 )?;
361 let libfunc = db.lookup_concrete_lib_func(&invc.libfunc_id);
362 let inner_function = if libfunc.generic_id == "function_call".into()
363 || libfunc.generic_id == "coupon_call".into()
364 {
365 libfunc.generic_args.first()?.clone()
366 } else if libfunc.generic_id == "coupon_buy".into()
367 || libfunc.generic_id == "coupon_refund".into()
368 {
369 let coupon_ty = try_extract_matches!(
374 libfunc.generic_args.first()?,
375 cairo_lang_sierra::program::GenericArg::Type
376 )?;
377 let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
378 coupon_long_id.generic_args.first()?.clone()
379 } else {
380 return None;
381 };
382
383 db.lookup_sierra_function(&try_extract_matches!(
384 inner_function,
385 cairo_lang_sierra::program::GenericArg::UserFunc
386 )?)
387 .body(db)
388 .expect("No diagnostics at this stage.")
389}
390
391#[salsa::tracked(returns(ref))]
392pub fn get_sierra_program<'db>(
393 db: &'db dyn Database,
394 _tracked: Tracked,
395 requested_crate_ids: Vec<CrateId<'db>>,
396) -> Maybe<SierraProgramWithDebug<'db>> {
397 let requested_function_ids = find_all_free_function_ids(db, requested_crate_ids)?;
398 db.get_sierra_program_for_functions(requested_function_ids).cloned()
399}
400
401pub fn find_all_free_function_ids<'db>(
403 db: &'db dyn Database,
404 requested_crate_ids: Vec<CrateId<'db>>,
405) -> Maybe<Vec<ConcreteFunctionWithBodyId<'db>>> {
406 let mut requested_function_ids = vec![];
407 for crate_id in requested_crate_ids {
408 for module_id in db.crate_modules(crate_id).iter() {
409 for (free_func_id, _) in module_id.module_data(db)?.free_functions(db).iter() {
410 if let Some(function) =
412 ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
413 {
414 requested_function_ids.push(function)
415 }
416 }
417 }
418 }
419 Ok(requested_function_ids)
420}
421
422pub fn get_dummy_program_for_size_estimation(
427 db: &dyn Database,
428 function_id: ConcreteFunctionWithBodyId<'_>,
429) -> Maybe<Program> {
430 let function = db.function_with_body_sierra(function_id)?;
431
432 let mut processed_function_ids =
433 UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::from_iter([function_id]);
434
435 let mut functions = vec![function];
436
437 for statement in &function.body {
438 if let Some(function_id) = try_get_function_with_body_id(db, statement) {
439 if processed_function_ids.insert(function_id) {
440 functions.push(db.priv_get_dummy_function(function_id)?);
441 }
442 }
443 }
444 let statements = functions
446 .iter()
447 .flat_map(|f| f.body.iter())
448 .map(|s| s.statement.clone().into_statement_without_location())
449 .collect();
450
451 Ok(assemble_program(db, functions, statements).program)
452}