cairo_lang_sierra_generator/
program_generator.rs1use std::collections::VecDeque;
2
3use cairo_lang_debug::DebugWithDb;
4use cairo_lang_defs::db::DefsGroup;
5use cairo_lang_defs::diagnostic_utils::StableLocation;
6use cairo_lang_diagnostics::{Maybe, get_location_marks};
7use cairo_lang_filesystem::ids::{CrateId, Tracked};
8use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
14use cairo_lang_utils::try_extract_matches;
15use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
16use itertools::{Itertools, chain};
17use salsa::Database;
18
19use crate::db::{SierraGenGroup, sierra_concrete_long_id};
20use crate::extra_sierra_info::type_has_const_size;
21use crate::pre_sierra;
22use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
23use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
24use crate::specialization_context::SierraSignatureSpecializationContext;
25use crate::statements_locations::StatementsLocations;
26
27#[cfg(test)]
28#[path = "program_generator_test.rs"]
29mod test;
30
31fn collect_and_generate_libfunc_declarations<'db>(
34 db: &dyn Database,
35 statements: &[pre_sierra::StatementWithLocation<'db>],
36) -> Vec<program::LibfuncDeclaration> {
37 let mut declared_libfuncs = UnorderedHashSet::<ConcreteLibfuncId>::default();
38 statements
39 .iter()
40 .filter_map(|statement| match &statement.statement {
41 pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
42 declared_libfuncs.insert(invocation.libfunc_id.clone()).then(|| {
43 program::LibfuncDeclaration {
44 id: invocation.libfunc_id.clone(),
45 long_id: db.lookup_concrete_lib_func(&invocation.libfunc_id),
46 }
47 })
48 }
49 pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
50 | pre_sierra::Statement::Label(_) => None,
51 pre_sierra::Statement::PushValues(_) => {
52 panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
53 }
54 })
55 .collect()
56}
57
58fn generate_type_declarations(
61 db: &dyn Database,
62 libfunc_declarations: &[program::LibfuncDeclaration],
63 functions: &[program::Function],
64) -> Vec<program::TypeDeclaration> {
65 let mut declarations = vec![];
66 let mut already_declared = UnorderedHashSet::default();
67 let mut remaining_types = collect_used_types(db, libfunc_declarations, functions);
68 while let Some(ty) = remaining_types.iter().next().cloned() {
69 remaining_types.swap_remove(&ty);
70 generate_type_declarations_helper(
71 db,
72 &ty,
73 &mut declarations,
74 &mut remaining_types,
75 &mut already_declared,
76 );
77 }
78 declarations
79}
80
81fn generate_type_declarations_helper(
86 db: &dyn Database,
87 ty: &ConcreteTypeId,
88 declarations: &mut Vec<program::TypeDeclaration>,
89 remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
90 already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
91) {
92 if already_declared.contains(ty) {
93 return;
94 }
95 let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
96 already_declared.insert(ty.clone());
97 let inner_tys = long_id
98 .generic_args
99 .iter()
100 .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
101 if type_has_const_size(&long_id.generic_id) {
104 remaining_types.extend(inner_tys.cloned());
105 } else {
106 for inner_ty in inner_tys {
107 generate_type_declarations_helper(
108 db,
109 inner_ty,
110 declarations,
111 remaining_types,
112 already_declared,
113 );
114 }
115 }
116
117 let type_info = db.get_type_info(ty.clone()).unwrap();
118 declarations.push(program::TypeDeclaration {
119 id: ty.clone(),
120 long_id: long_id.as_ref().clone(),
121 declared_type_info: Some(DeclaredTypeInfo {
122 storable: type_info.storable,
123 droppable: type_info.droppable,
124 duplicatable: type_info.duplicatable,
125 zero_sized: type_info.zero_sized,
126 }),
127 });
128}
129
130fn collect_used_types(
133 db: &dyn Database,
134 libfunc_declarations: &[program::LibfuncDeclaration],
135 functions: &[program::Function],
136) -> OrderedHashSet<ConcreteTypeId> {
137 let mut all_types = OrderedHashSet::default();
138 for libfunc in libfunc_declarations {
140 let types = db.priv_libfunc_dependencies(libfunc.id.clone());
141 all_types.extend(types.iter().cloned());
142 }
143
144 all_types.extend(
152 functions.iter().flat_map(|func| {
153 chain!(&func.signature.param_types, &func.signature.ret_types).cloned()
154 }),
155 );
156 all_types
157}
158
159#[salsa::tracked(returns(ref))]
161pub fn priv_libfunc_dependencies(
162 db: &dyn Database,
163 _tracked: Tracked,
164 libfunc_id: ConcreteLibfuncId,
165) -> Vec<ConcreteTypeId> {
166 let long_id = db.lookup_concrete_lib_func(&libfunc_id);
167 let signature = CoreLibfunc::specialize_signature_by_id(
168 &SierraSignatureSpecializationContext(db),
169 &long_id.generic_id,
170 &long_id.generic_args,
171 )
172 .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
175 DebugReplacer { db }.replace_libfunc_id(&libfunc_id)));
176 let mut all_types = vec![];
178 let mut add_ty = |ty: ConcreteTypeId| {
179 if !all_types.contains(&ty) {
180 all_types.push(ty);
181 }
182 };
183 for param_signature in signature.param_signatures {
184 add_ty(param_signature.ty);
185 }
186 for info in signature.branch_signatures {
187 for var in info.vars {
188 add_ty(var.ty);
189 }
190 }
191 for arg in long_id.generic_args {
192 if let program::GenericArg::Type(ty) = arg {
193 add_ty(ty);
194 }
195 }
196 all_types
197}
198
199#[derive(Clone, Debug, Eq, PartialEq)]
200pub struct SierraProgramWithDebug<'db> {
201 pub program: cairo_lang_sierra::program::Program,
202 pub debug_info: SierraProgramDebugInfo<'db>,
203}
204
205unsafe impl<'db> salsa::Update for SierraProgramWithDebug<'db> {
206 unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
207 let old_value = unsafe { &mut *old_pointer };
208 if old_value == &new_value {
209 return false;
210 }
211 *old_value = new_value;
212 true
213 }
214}
215impl<'db> DebugWithDb<'db> for SierraProgramWithDebug<'db> {
218 type Db = dyn Database;
219
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn Database) -> std::fmt::Result {
221 let sierra_program = DebugReplacer { db }.apply(&self.program);
222 for declaration in &sierra_program.type_declarations {
223 writeln!(f, "{declaration};")?;
224 }
225 writeln!(f)?;
226 for declaration in &sierra_program.libfunc_declarations {
227 writeln!(f, "{declaration};")?;
228 }
229 writeln!(f)?;
230 let mut funcs = sierra_program.funcs.iter().peekable();
231 while let Some(func) = funcs.next() {
232 let start = func.entry_point.0;
233 let end = funcs
234 .peek()
235 .map(|f| f.entry_point.0)
236 .unwrap_or_else(|| sierra_program.statements.len());
237 writeln!(f, "// {}:", func.id)?;
238 for param in &func.params {
239 writeln!(f, "// {param}")?;
240 }
241 for i in start..end {
242 writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
243 if let Some(loc) =
244 &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
245 {
246 let loc =
247 get_location_marks(db, &loc.first().unwrap().diagnostic_location(db), true);
248 println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
249 }
250 }
251 }
252 writeln!(f)?;
253 for func in &sierra_program.funcs {
254 writeln!(f, "{func};")?;
255 }
256 Ok(())
257 }
258}
259
260#[derive(Clone, Debug, Eq, PartialEq)]
261pub struct SierraProgramDebugInfo<'db> {
262 pub statements_locations: StatementsLocations<'db>,
263}
264
265#[salsa::tracked(returns(ref))]
266pub fn get_sierra_program_for_functions<'db>(
267 db: &'db dyn Database,
268 _tracked: Tracked,
269 requested_function_ids: Vec<ConcreteFunctionWithBodyId<'db>>,
270) -> Maybe<SierraProgramWithDebug<'db>> {
271 let mut functions: Vec<&'db pre_sierra::Function<'_>> = vec![];
272 let mut statements: Vec<pre_sierra::StatementWithLocation<'_>> = vec![];
273 let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::default();
274 let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId<'_>> =
275 requested_function_ids.into_iter().collect();
276 while let Some(function_id) = function_id_queue.pop_front() {
277 if !processed_function_ids.insert(function_id) {
278 continue;
279 }
280 let function = db.function_with_body_sierra(function_id)?;
281 functions.push(function);
282 statements.extend_from_slice(&function.body);
283
284 for statement in &function.body {
285 if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
286 function_id_queue.push_back(related_function_id);
287 }
288 }
289 }
290
291 let (program, statements_locations) = assemble_program(db, functions, statements);
292 Ok(SierraProgramWithDebug {
293 program,
294 debug_info: SierraProgramDebugInfo {
295 statements_locations: StatementsLocations::from_locations_vec(&statements_locations),
296 },
297 })
298}
299
300pub fn assemble_program<'db>(
303 db: &dyn Database,
304 functions: Vec<&'db pre_sierra::Function<'db>>,
305 statements: Vec<pre_sierra::StatementWithLocation<'db>>,
306) -> (program::Program, Vec<Vec<StableLocation<'db>>>) {
307 let label_replacer = LabelReplacer::from_statements(&statements);
308 let funcs = functions
309 .into_iter()
310 .map(|function| {
311 let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
312 program::Function::new(
313 function.id.clone(),
314 function.parameters.clone(),
315 sierra_signature.ret_types.clone(),
316 label_replacer.handle_label_id(function.entry_point),
317 )
318 })
319 .collect_vec();
320
321 let libfunc_declarations = collect_and_generate_libfunc_declarations(db, &statements);
322 let type_declarations = generate_type_declarations(db, &libfunc_declarations, &funcs);
323 let (resolved_statements, statements_locations) =
325 resolve_labels_and_extract_locations(statements, &label_replacer);
326 let program = program::Program {
327 type_declarations,
328 libfunc_declarations,
329 statements: resolved_statements,
330 funcs,
331 };
332 (program, statements_locations)
333}
334
335pub fn try_get_function_with_body_id<'db>(
337 db: &'db dyn Database,
338 statement: &pre_sierra::StatementWithLocation<'db>,
339) -> Option<ConcreteFunctionWithBodyId<'db>> {
340 let invc = try_extract_matches!(
341 try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
342 program::GenStatement::Invocation
343 )?;
344 let libfunc = db.lookup_concrete_lib_func(&invc.libfunc_id);
345 let inner_function = if libfunc.generic_id == "function_call".into()
346 || libfunc.generic_id == "coupon_call".into()
347 {
348 libfunc.generic_args.first()?.clone()
349 } else if libfunc.generic_id == "coupon_buy".into()
350 || libfunc.generic_id == "coupon_refund".into()
351 {
352 let coupon_ty = try_extract_matches!(
357 libfunc.generic_args.first()?,
358 cairo_lang_sierra::program::GenericArg::Type
359 )?;
360 let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
361 coupon_long_id.generic_args.first()?.clone()
362 } else {
363 return None;
364 };
365
366 db.lookup_sierra_function(&try_extract_matches!(
367 inner_function,
368 cairo_lang_sierra::program::GenericArg::UserFunc
369 )?)
370 .body(db)
371 .expect("No diagnostics at this stage.")
372}
373
374#[salsa::tracked(returns(ref))]
375pub fn get_sierra_program<'db>(
376 db: &'db dyn Database,
377 _tracked: Tracked,
378 requested_crate_ids: Vec<CrateId<'db>>,
379) -> Maybe<SierraProgramWithDebug<'db>> {
380 let requested_function_ids = find_all_free_function_ids(db, requested_crate_ids)?;
381 db.get_sierra_program_for_functions(requested_function_ids).cloned()
382}
383
384pub fn find_all_free_function_ids<'db>(
386 db: &'db dyn Database,
387 requested_crate_ids: Vec<CrateId<'db>>,
388) -> Maybe<Vec<ConcreteFunctionWithBodyId<'db>>> {
389 let mut requested_function_ids = vec![];
390 for crate_id in requested_crate_ids {
391 for module_id in db.crate_modules(crate_id).iter() {
392 for (free_func_id, _) in module_id.module_data(db)?.free_functions(db).iter() {
393 if let Some(function) =
395 ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
396 {
397 requested_function_ids.push(function)
398 }
399 }
400 }
401 }
402 Ok(requested_function_ids)
403}
404
405pub fn get_dummy_program_for_size_estimation(
410 db: &dyn Database,
411 function_id: ConcreteFunctionWithBodyId<'_>,
412) -> Maybe<Program> {
413 let function = db.function_with_body_sierra(function_id)?;
414
415 let mut processed_function_ids =
416 UnorderedHashSet::<ConcreteFunctionWithBodyId<'_>>::from_iter([function_id]);
417
418 let mut functions = vec![function];
419
420 for statement in &function.body {
421 if let Some(function_id) = try_get_function_with_body_id(db, statement) {
422 if processed_function_ids.insert(function_id) {
423 functions.push(db.priv_get_dummy_function(function_id)?);
424 }
425 }
426 }
427 let statements = functions
429 .iter()
430 .flat_map(|f| f.body.iter())
431 .map(|s| s.statement.clone().into_statement_without_location())
432 .collect();
433
434 let (program, _statements_locations) = assemble_program(db, functions, statements);
435
436 Ok(program)
437}