cairo_lang_sierra_generator/
program_generator.rs1use std::collections::VecDeque;
2use std::sync::Arc;
3
4use cairo_lang_debug::DebugWithDb;
5use cairo_lang_defs::diagnostic_utils::StableLocation;
6use cairo_lang_diagnostics::{Maybe, get_location_marks};
7use cairo_lang_filesystem::ids::CrateId;
8use cairo_lang_lowering::ids::ConcreteFunctionWithBodyId;
9use cairo_lang_sierra::extensions::GenericLibfuncEx;
10use cairo_lang_sierra::extensions::core::CoreLibfunc;
11use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId};
12use cairo_lang_sierra::program::{self, DeclaredTypeInfo, Program, StatementIdx};
13use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
14use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
15use cairo_lang_utils::{LookupIntern, try_extract_matches};
16use itertools::{Itertools, chain};
17
18use crate::db::{SierraGenGroup, sierra_concrete_long_id};
19use crate::extra_sierra_info::type_has_const_size;
20use crate::pre_sierra;
21use crate::replace_ids::{DebugReplacer, SierraIdReplacer};
22use crate::resolve_labels::{LabelReplacer, resolve_labels_and_extract_locations};
23use crate::specialization_context::SierraSignatureSpecializationContext;
24use crate::statements_locations::StatementsLocations;
25
26#[cfg(test)]
27#[path = "program_generator_test.rs"]
28mod test;
29
30fn generate_libfunc_declarations<'a>(
33 db: &dyn SierraGenGroup,
34 libfuncs: impl Iterator<Item = &'a ConcreteLibfuncId>,
35) -> Vec<program::LibfuncDeclaration> {
36 libfuncs
37 .into_iter()
38 .map(|libfunc_id| program::LibfuncDeclaration {
39 id: libfunc_id.clone(),
40 long_id: libfunc_id.lookup_intern(db),
41 })
42 .collect()
43}
44
45fn collect_used_libfuncs(
47 statements: &[pre_sierra::StatementWithLocation],
48) -> OrderedHashSet<ConcreteLibfuncId> {
49 statements
50 .iter()
51 .filter_map(|statement| match &statement.statement {
52 pre_sierra::Statement::Sierra(program::GenStatement::Invocation(invocation)) => {
53 Some(invocation.libfunc_id.clone())
54 }
55 pre_sierra::Statement::Sierra(program::GenStatement::Return(_))
56 | pre_sierra::Statement::Label(_) => None,
57 pre_sierra::Statement::PushValues(_) => {
58 panic!("Unexpected pre_sierra::Statement::PushValues in collect_used_libfuncs().")
59 }
60 })
61 .collect()
62}
63
64fn generate_type_declarations(
67 db: &dyn SierraGenGroup,
68 mut remaining_types: OrderedHashSet<ConcreteTypeId>,
69) -> Vec<program::TypeDeclaration> {
70 let mut declarations = vec![];
71 let mut already_declared = UnorderedHashSet::default();
72 while let Some(ty) = remaining_types.iter().next().cloned() {
73 remaining_types.swap_remove(&ty);
74 generate_type_declarations_helper(
75 db,
76 &ty,
77 &mut declarations,
78 &mut remaining_types,
79 &mut already_declared,
80 );
81 }
82 declarations
83}
84
85fn generate_type_declarations_helper(
90 db: &dyn SierraGenGroup,
91 ty: &ConcreteTypeId,
92 declarations: &mut Vec<program::TypeDeclaration>,
93 remaining_types: &mut OrderedHashSet<ConcreteTypeId>,
94 already_declared: &mut UnorderedHashSet<ConcreteTypeId>,
95) {
96 if already_declared.contains(ty) {
97 return;
98 }
99 let long_id = sierra_concrete_long_id(db, ty.clone()).unwrap();
100 already_declared.insert(ty.clone());
101 let inner_tys = long_id
102 .generic_args
103 .iter()
104 .filter_map(|arg| try_extract_matches!(arg, program::GenericArg::Type));
105 if type_has_const_size(&long_id.generic_id) {
108 remaining_types.extend(inner_tys.cloned());
109 } else {
110 for inner_ty in inner_tys {
111 generate_type_declarations_helper(
112 db,
113 inner_ty,
114 declarations,
115 remaining_types,
116 already_declared,
117 );
118 }
119 }
120
121 let type_info = db.get_type_info(ty.clone()).unwrap();
122 declarations.push(program::TypeDeclaration {
123 id: ty.clone(),
124 long_id: long_id.as_ref().clone(),
125 declared_type_info: Some(DeclaredTypeInfo {
126 storable: type_info.storable,
127 droppable: type_info.droppable,
128 duplicatable: type_info.duplicatable,
129 zero_sized: type_info.zero_sized,
130 }),
131 });
132}
133
134fn collect_used_types(
137 db: &dyn SierraGenGroup,
138 libfunc_declarations: &[program::LibfuncDeclaration],
139 functions: &[program::Function],
140) -> OrderedHashSet<ConcreteTypeId> {
141 let types_in_libfuncs = libfunc_declarations.iter().flat_map(|libfunc| {
143 let signature = CoreLibfunc::specialize_signature_by_id(
145 &SierraSignatureSpecializationContext(db),
146 &libfunc.long_id.generic_id,
147 &libfunc.long_id.generic_args,
148 )
149 .unwrap_or_else(|err| panic!("Failed to specialize: `{}`. Error: {err}",
152 DebugReplacer { db }.replace_libfunc_id(&libfunc.id)));
153 chain!(
154 signature.param_signatures.into_iter().map(|param_signature| param_signature.ty),
155 signature.branch_signatures.into_iter().flat_map(|info| info.vars).map(|var| var.ty),
156 libfunc.long_id.generic_args.iter().filter_map(|arg| match arg {
157 program::GenericArg::Type(ty) => Some(ty.clone()),
158 _ => None,
159 })
160 )
161 });
162
163 let types_in_user_functions = functions
171 .iter()
172 .flat_map(|func| chain!(&func.signature.param_types, &func.signature.ret_types).cloned());
173
174 chain!(types_in_libfuncs, types_in_user_functions).collect()
175}
176
177#[derive(Clone, Debug, Eq, PartialEq)]
178pub struct SierraProgramWithDebug {
179 pub program: cairo_lang_sierra::program::Program,
180 pub debug_info: SierraProgramDebugInfo,
181}
182impl DebugWithDb<dyn SierraGenGroup> for SierraProgramWithDebug {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>, db: &dyn SierraGenGroup) -> std::fmt::Result {
186 let sierra_program = DebugReplacer { db }.apply(&self.program);
187 for declaration in &sierra_program.type_declarations {
188 writeln!(f, "{declaration};")?;
189 }
190 writeln!(f)?;
191 for declaration in &sierra_program.libfunc_declarations {
192 writeln!(f, "{declaration};")?;
193 }
194 writeln!(f)?;
195 let mut funcs = sierra_program.funcs.iter().peekable();
196 while let Some(func) = funcs.next() {
197 let start = func.entry_point.0;
198 let end = funcs
199 .peek()
200 .map(|f| f.entry_point.0)
201 .unwrap_or_else(|| sierra_program.statements.len());
202 writeln!(f, "// {}:", func.id)?;
203 for param in &func.params {
204 writeln!(f, "// {param}")?;
205 }
206 for i in start..end {
207 writeln!(f, "{}; // {i}", sierra_program.statements[i])?;
208 if let Some(loc) =
209 &self.debug_info.statements_locations.locations.get(&StatementIdx(i))
210 {
211 let loc =
212 get_location_marks(db, &loc.first().unwrap().diagnostic_location(db), true);
213 println!("{}", loc.split('\n').map(|l| format!("// {l}")).join("\n"));
214 }
215 }
216 }
217 writeln!(f)?;
218 for func in &sierra_program.funcs {
219 writeln!(f, "{func};")?;
220 }
221 Ok(())
222 }
223}
224
225#[derive(Clone, Debug, Eq, PartialEq)]
226pub struct SierraProgramDebugInfo {
227 pub statements_locations: StatementsLocations,
228}
229
230pub fn get_sierra_program_for_functions(
231 db: &dyn SierraGenGroup,
232 requested_function_ids: Vec<ConcreteFunctionWithBodyId>,
233) -> Maybe<Arc<SierraProgramWithDebug>> {
234 let mut functions: Vec<Arc<pre_sierra::Function>> = vec![];
235 let mut statements: Vec<pre_sierra::StatementWithLocation> = vec![];
236 let mut processed_function_ids = UnorderedHashSet::<ConcreteFunctionWithBodyId>::default();
237 let mut function_id_queue: VecDeque<ConcreteFunctionWithBodyId> =
238 requested_function_ids.into_iter().collect();
239 while let Some(function_id) = function_id_queue.pop_front() {
240 if !processed_function_ids.insert(function_id) {
241 continue;
242 }
243 let function: Arc<pre_sierra::Function> = db.function_with_body_sierra(function_id)?;
244 functions.push(function.clone());
245 statements.extend_from_slice(&function.body);
246
247 for statement in &function.body {
248 if let Some(related_function_id) = try_get_function_with_body_id(db, statement) {
249 function_id_queue.push_back(related_function_id);
250 }
251 }
252 }
253
254 let (program, statements_locations) = assemble_program(db, functions, statements);
255 Ok(Arc::new(SierraProgramWithDebug {
256 program,
257 debug_info: SierraProgramDebugInfo {
258 statements_locations: StatementsLocations::from_locations_vec(&statements_locations),
259 },
260 }))
261}
262
263pub fn assemble_program(
266 db: &dyn SierraGenGroup,
267 functions: Vec<Arc<pre_sierra::Function>>,
268 statements: Vec<pre_sierra::StatementWithLocation>,
269) -> (program::Program, Vec<Vec<StableLocation>>) {
270 let label_replacer = LabelReplacer::from_statements(&statements);
271 let funcs = functions
272 .into_iter()
273 .map(|function| {
274 let sierra_signature = db.get_function_signature(function.id.clone()).unwrap();
275 program::Function::new(
276 function.id.clone(),
277 function.parameters.clone(),
278 sierra_signature.ret_types.clone(),
279 label_replacer.handle_label_id(function.entry_point),
280 )
281 })
282 .collect_vec();
283
284 let libfunc_declarations =
285 generate_libfunc_declarations(db, collect_used_libfuncs(&statements).iter());
286 let type_declarations =
287 generate_type_declarations(db, collect_used_types(db, &libfunc_declarations, &funcs));
288 let (resolved_statements, statements_locations) =
290 resolve_labels_and_extract_locations(statements, &label_replacer);
291 let program = program::Program {
292 type_declarations,
293 libfunc_declarations,
294 statements: resolved_statements,
295 funcs,
296 };
297 (program, statements_locations)
298}
299
300pub fn try_get_function_with_body_id(
302 db: &dyn SierraGenGroup,
303 statement: &pre_sierra::StatementWithLocation,
304) -> Option<ConcreteFunctionWithBodyId> {
305 let invc = try_extract_matches!(
306 try_extract_matches!(&statement.statement, pre_sierra::Statement::Sierra)?,
307 program::GenStatement::Invocation
308 )?;
309 let libfunc = invc.libfunc_id.lookup_intern(db);
310 let inner_function = if libfunc.generic_id == "function_call".into()
311 || libfunc.generic_id == "coupon_call".into()
312 {
313 libfunc.generic_args.first()?.clone()
314 } else if libfunc.generic_id == "coupon_buy".into()
315 || libfunc.generic_id == "coupon_refund".into()
316 {
317 let coupon_ty = try_extract_matches!(
322 libfunc.generic_args.first()?,
323 cairo_lang_sierra::program::GenericArg::Type
324 )?;
325 let coupon_long_id = sierra_concrete_long_id(db, coupon_ty.clone()).unwrap();
326 coupon_long_id.generic_args.first()?.clone()
327 } else {
328 return None;
329 };
330
331 try_extract_matches!(inner_function, cairo_lang_sierra::program::GenericArg::UserFunc)?
332 .lookup_intern(db)
333 .body(db)
334 .expect("No diagnostics at this stage.")
335}
336
337pub fn get_sierra_program(
338 db: &dyn SierraGenGroup,
339 requested_crate_ids: Vec<CrateId>,
340) -> Maybe<Arc<SierraProgramWithDebug>> {
341 let mut requested_function_ids = vec![];
342 for crate_id in requested_crate_ids {
343 for module_id in db.crate_modules(crate_id).iter() {
344 for (free_func_id, _) in db.module_free_functions(*module_id)?.iter() {
345 if let Some(function) =
347 ConcreteFunctionWithBodyId::from_no_generics_free(db, *free_func_id)
348 {
349 requested_function_ids.push(function)
350 }
351 }
352 }
353 }
354 db.get_sierra_program_for_functions(requested_function_ids)
355}
356
357pub fn get_dummy_program_for_size_estimation(
362 db: &dyn SierraGenGroup,
363 function_id: ConcreteFunctionWithBodyId,
364) -> Maybe<Program> {
365 let function: Arc<pre_sierra::Function> = db.function_with_body_sierra(function_id)?;
366
367 let mut processed_function_ids =
368 UnorderedHashSet::<ConcreteFunctionWithBodyId>::from_iter([function_id]);
369
370 let mut functions = vec![function.clone()];
371
372 for statement in &function.body {
373 if let Some(function_id) = try_get_function_with_body_id(db, statement) {
374 if processed_function_ids.insert(function_id) {
375 functions.push(db.priv_get_dummy_function(function_id)?);
376 }
377 }
378 }
379 let statements = functions
381 .iter()
382 .flat_map(|f| f.body.iter())
383 .map(|s| s.statement.clone().into_statement_without_location())
384 .collect();
385
386 let (program, _statements_locations) = assemble_program(db, functions, statements);
387
388 Ok(program)
389}