cairo_lang_sierra_generator/
program_generator.rs1use std::collections::VecDeque;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs::db::DefsGroup;
5use cairo_lang_diagnostics::Maybe;
6use cairo_lang_filesystem::ids::{CrateId, Tracked};
7use cairo_lang_filesystem::location_marks::get_location_marks;
8use cairo_lang_lowering::ids::{ConcreteFunctionWithBodyId, LocationId};
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
14use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
15use cairo_lang_utils::try_extract_matches;
16use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
17use itertools::{Itertools, chain};
18use salsa::Database;
19
20use crate::db::{SierraGenGroup, sierra_concrete_long_id};
21use crate::debug_info::{
22 AllFunctionsDebugInfo, FunctionDebugInfo, SierraProgramDebugInfo, StatementsLocations,
23};
24use crate::extra_sierra_info::type_has_const_size;
25use crate::pre_sierra;
26use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
27use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
28use crate::specialization_context::SierraSignatureSpecializationContext;
29
30#[cfg(test)]
31#[path = "program_generator_test.rs"]
32mod test;
33
34fn collect_and_generate_libfunc_declarations<'db>(
37 db: &dyn Database,
38 statements: &[pre_sierra::StatementWithLocation<'db>],
39) -> Vec<program::LibfuncDeclaration> {
40 let mut declared_libfuncs = UnorderedHashSet::<ConcreteLibfuncId>::default();
41 statements
42 .iter()
43 .filter_map(|statement| match &statement.statement {
44 pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
45 declared_libfuncs.insert(invocation.libfunc_id.clone()).then(|| {
46 program::LibfuncDeclaration {
47 id: invocation.libfunc_id.clone(),
48 long_id: db.lookup_concrete_lib_func(&invocation.libfunc_id),
49 }
50 })
51 }
52 pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
53 | pre_sierra::Statement::Label(_) => None,
54 pre_sierra::Statement::PushValues(_) => {
55 panic!(
56 "Unexpected pre_sierra::Statement::PushValues in \
57 collect_and_generate_libfunc_declarations()."
58 )
59 }
60 })
61 .collect()
62}
63
64fn generate_type_declarations(
67 db: &dyn Database,
68 libfunc_declarations: &[program::LibfuncDeclaration],
69 functions: &[program::Function],
70) -> Vec<program::TypeDeclaration> {
71 let mut declarations = vec![];
72 let mut already_declared = UnorderedHashSet::default();
73 let mut remaining_types = collect_used_types(db, libfunc_declarations, functions);
74 while let Some(ty) = remaining_types.iter().next().cloned() {
75 remaining_types.swap_remove(&ty);
76 generate_type_declarations_helper(
77 db,
78 &ty,
79 &mut declarations,
80 &mut remaining_types,
81 &mut already_declared,
82 );
83 }
84 declarations
85}
86
87fn generate_type_declarations_helper(
92 db: &dyn Database,
93 ty: &ConcreteTypeId,
94 declarations: &mut Vec<program::TypeDeclaration>,
95 remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
96 already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
97) {
98 if already_declared.contains(ty) {
99 return;
100 }
101 let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
102 already_declared.insert(ty.clone());
103 let inner_tys = long_id
104 .generic_args
105 .iter()
106 .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
107 if type_has_const_size(&long_id.generic_id) {
110 remaining_types.extend(inner_tys.cloned());
111 } else {
112 for inner_ty in inner_tys {
113 generate_type_declarations_helper(
114 db,
115 inner_ty,
116 declarations,
117 remaining_types,
118 already_declared,
119 );
120 }
121 }
122
123 let type_info = db.get_type_info(ty.clone()).unwrap();
124 declarations.push(program::TypeDeclaration {
125 id: ty.clone(),
126 long_id: long_id.as_ref().clone(),
127 declared_type_info: Some(DeclaredTypeInfo {
128 storable: type_info.storable,
129 droppable: type_info.droppable,
130 duplicatable: type_info.duplicatable,
131 zero_sized: type_info.zero_sized,
132 }),
133 });
134}
135
136fn collect_used_types(
139 db: &dyn Database,
140 libfunc_declarations: &[program::LibfuncDeclaration],
141 functions: &[program::Function],
142) -> OrderedHashSet<ConcreteTypeId> {
143 let mut all_types = OrderedHashSet::default();
144 for libfunc in libfunc_declarations {
146 let types = db.priv_libfunc_dependencies(libfunc.id.clone());
147 all_types.extend(types.iter().cloned());
148 }
149
150 all_types.extend(
158 functions.iter().flat_map(|func| {
159 chain!(&func.signature.param_types, &func.signature.ret_types).cloned()
160 }),
161 );
162 all_types
163}
164
165#[salsa::tracked(returns(ref))]
167pub fn priv_libfunc_dependencies(
168 db: &dyn Database,
169 _tracked: Tracked,
170 libfunc_id: ConcreteLibfuncId,
171) -> Vec<ConcreteTypeId> {
172 let long_id = db.lookup_concrete_lib_func(&libfunc_id);
173 let signature = CoreLibfunc::specialize_signature_by_id(
174 &SierraSignatureSpecializationContext(db),
175 &long_id.generic_id,
176 &long_id.generic_args,
177 )
178 .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
181 DebugReplacer { db }.replace_libfunc_id(&libfunc_id)));
182 let mut all_types = vec![];
184 let mut add_ty = |ty: ConcreteTypeId| {
185 if !all_types.contains(&ty) {
186 all_types.push(ty);
187 }
188 };
189 for param_signature in signature.param_signatures {
190 add_ty(param_signature.ty);
191 }
192 for info in signature.branch_signatures {
193 for var in info.vars {
194 add_ty(var.ty);
195 }
196 }
197 for arg in long_id.generic_args {
198 if let program::GenericArg::Type(ty) = arg {
199 add_ty(ty);
200 }
201 }
202 all_types
203}
204
205#[derive(Clone, Debug, Eq, PartialEq)]
206pub struct SierraProgramWithDebug<'db> {
207 pub program: cairo_lang_sierra::program::Program,
208 pub debug_info: SierraProgramDebugInfo<'db>,
209}
210
211unsafe impl<'db> salsa::Update for SierraProgramWithDebug<'db> {
212 unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
213 let old_value = unsafe { &mut *old_pointer };
214 if old_value == &new_value {
215 return false;
216 }
217 *old_value = new_value;
218 true
219 }
220}
221impl<'db> DebugWithDb<'db> for SierraProgramWithDebug<'db> {
224 type Db = dyn Database;
225
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Database) -> std::fmt::Result {
227 let sierra_program = DebugReplacer { db }.apply(&self.program);
228 for declaration in &sierra_program.type_declarations {
229 writeln!(f, "{declaration};")?;
230 }
231 writeln!(f)?;
232 for declaration in &sierra_program.libfunc_declarations {
233 writeln!(f, "{declaration};")?;
234 }
235 writeln!(f)?;
236 let mut funcs = sierra_program.funcs.iter().peekable();
237 while let Some(func) = funcs.next() {
238 let start = func.entry_point.0;
239 let end = funcs
240 .peek()
241 .map(|f| f.entry_point.0)
242 .unwrap_or_else(|| sierra_program.statements.len());
243 writeln!(f, "// {}:", func.id)?;
244 for param in &func.params {
245 writeln!(f, "// {param}")?;
246 }
247 for i in start..end {
248 writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
249 if let Some(loc) =
250 &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
251 {
252 let loc = get_location_marks(db, &loc.first().unwrap().span_in_file(db), true);
253 println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
254 }
255 }
256 }
257 writeln!(f)?;
258 for func in &sierra_program.funcs {
259 writeln!(f, "{func};")?;
260 }
261 Ok(())
262 }
263}
264
265#[salsa::tracked(returns(ref))]
266pub fn get_sierra_program_for_functions<'db>(
267 db: &'db dyn Database,
268 _tracked: Tracked,
269 requested_function_ids: Vec<ConcreteFunctionWithBodyId<'db>>,
270) -> Maybe<SierraProgramWithDebug<'db>> {
271 let mut functions: Vec<&'db pre_sierra::Function<'_>> = vec![];
272 let mut statements: Vec<pre_sierra::StatementWithLocation<'_>> = vec![];
273 let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::default();
274 let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId<'_>> =
275 requested_function_ids.into_iter().collect();
276 while let Some(function_id) = function_id_queue.pop_front() {
277 if !processed_function_ids.insert(function_id) {
278 continue;
279 }
280 let function = db.function_with_body_sierra(function_id)?;
281 functions.push(function);
282 statements.extend_from_slice(&function.body);
283
284 for statement in &function.body {
285 if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
286 function_id_queue.push_back(related_function_id);
287 }
288 }
289 }
290
291 let AssembledProgram { program, statements_locations, functions_info } =
292 assemble_program(db, functions, statements);
293 Ok(SierraProgramWithDebug {
294 program,
295 debug_info: SierraProgramDebugInfo {
296 statements_locations: StatementsLocations::from_locations_vec(db, statements_locations),
297 functions_info: AllFunctionsDebugInfo::new(functions_info),
298 },
299 })
300}
301
302struct AssembledProgram<'db> {
304 program: program::Program,
306 statements_locations: Vec<Option<LocationId<'db>>>,
308 functions_info: OrderedHashMap<FunctionId, FunctionDebugInfo<'db>>,
310}
311
312fn assemble_program<'db>(
315 db: &dyn Database,
316 functions: Vec<&'db pre_sierra::Function<'db>>,
317 statements: Vec<pre_sierra::StatementWithLocation<'db>>,
318) -> AssembledProgram<'db> {
319 let label_replacer = LabelReplacer::from_statements(&statements);
320 let functions_info = functions
321 .iter()
322 .map(|f| {
323 (
324 f.id.clone(),
325 FunctionDebugInfo {
326 signature_location: f.signature_location,
327 variables_locations: f.variable_locations.iter().cloned().collect(),
328 },
329 )
330 })
331 .collect();
332 let funcs = functions
333 .into_iter()
334 .map(|function| {
335 let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
336 program::Function::new(
337 function.id.clone(),
338 function.parameters.clone(),
339 sierra_signature.ret_types.clone(),
340 label_replacer.handle_label_id(function.entry_point),
341 )
342 })
343 .collect_vec();
344
345 let libfunc_declarations = collect_and_generate_libfunc_declarations(db, &statements);
346 let type_declarations = generate_type_declarations(db, &libfunc_declarations, &funcs);
347 let (resolved_statements, statements_locations) =
349 resolve_labels_and_extract_locations(statements, &label_replacer);
350 let program = program::Program {
351 type_declarations,
352 libfunc_declarations,
353 statements: resolved_statements,
354 funcs,
355 };
356 AssembledProgram { program, statements_locations, functions_info }
357}
358
359pub fn try_get_function_with_body_id<'db>(
361 db: &'db dyn Database,
362 statement: &pre_sierra::StatementWithLocation<'db>,
363) -> Option<ConcreteFunctionWithBodyId<'db>> {
364 let invc = try_extract_matches!(
365 try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
366 program::GenStatement::Invocation
367 )?;
368 let libfunc = db.lookup_concrete_lib_func(&invc.libfunc_id);
369 let inner_function = if libfunc.generic_id == "function_call".into()
370 || libfunc.generic_id == "coupon_call".into()
371 {
372 libfunc.generic_args.first()?.clone()
373 } else if libfunc.generic_id == "coupon_buy".into()
374 || libfunc.generic_id == "coupon_refund".into()
375 {
376 let coupon_ty = try_extract_matches!(
381 libfunc.generic_args.first()?,
382 cairo_lang_sierra::program::GenericArg::Type
383 )?;
384 let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
385 coupon_long_id.generic_args.first()?.clone()
386 } else {
387 return None;
388 };
389
390 db.lookup_sierra_function(&try_extract_matches!(
391 inner_function,
392 cairo_lang_sierra::program::GenericArg::UserFunc
393 )?)
394 .body(db)
395 .expect("No diagnostics at this stage.")
396}
397
398#[salsa::tracked(returns(ref))]
399pub fn get_sierra_program<'db>(
400 db: &'db dyn Database,
401 _tracked: Tracked,
402 requested_crate_ids: Vec<CrateId<'db>>,
403) -> Maybe<SierraProgramWithDebug<'db>> {
404 let requested_function_ids = find_all_free_function_ids(db, requested_crate_ids)?;
405 db.get_sierra_program_for_functions(requested_function_ids).cloned()
406}
407
408pub fn find_all_free_function_ids<'db>(
410 db: &'db dyn Database,
411 requested_crate_ids: Vec<CrateId<'db>>,
412) -> Maybe<Vec<ConcreteFunctionWithBodyId<'db>>> {
413 let mut requested_function_ids = vec![];
414 for crate_id in requested_crate_ids {
415 for module_id in db.crate_modules(crate_id).iter() {
416 for (free_func_id, _) in module_id.module_data(db)?.free_functions(db).iter() {
417 if let Some(function) =
419 ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
420 {
421 requested_function_ids.push(function)
422 }
423 }
424 }
425 }
426 Ok(requested_function_ids)
427}
428
429pub fn get_dummy_program_for_size_estimation(
434 db: &dyn Database,
435 function_id: ConcreteFunctionWithBodyId<'_>,
436) -> Maybe<Program> {
437 let function = db.function_with_body_sierra(function_id)?;
438
439 let mut processed_function_ids =
440 UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::from_iter([function_id]);
441
442 let mut functions = vec![function];
443
444 for statement in &function.body {
445 if let Some(function_id) = try_get_function_with_body_id(db, statement) {
446 if processed_function_ids.insert(function_id) {
447 functions.push(db.priv_get_dummy_function(function_id)?);
448 }
449 }
450 }
451 let statements = functions
453 .iter()
454 .flat_map(|f| f.body.iter())
455 .map(|s| s.statement.clone().into_statement_without_location())
456 .collect();
457
458 Ok(assemble_program(db, functions, statements).program)
459}